diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..2b50a4a9d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,69 @@ +# Git +.git/ +.gitignore +.github/ + +# Node.js +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# Build outputs +dist/ +build/ +out/ +*.tsbuildinfo + +# Logs +logs/ +*.log + +# Environment files +.env +.env.* +*.env + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Testing +coverage/ +.nyc_output/ + +# Cache directories +.cache/ +.parcel-cache/ + +# Documentation that's not needed for build +docs/ +README.md +*.md + +# CI/CD +ci/ + +# Plugin build artifacts +plugins/*/dist/ + +# Test directories +tests/ +test/ +__tests__/ + +# Temporary files +tmp/ +temp/ +.tmp/ + +# Go workspaces (local only) +go.work +go.work.sum \ No newline at end of file diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..7223b342a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,9 @@ +root = true + +[*] +insert_final_newline = false +end_of_line = lf +charset = utf-8 + +[*.{js,jsx,ts,tsx,mjs,json,md,css,scss,html}] +insert_final_newline = false diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 000000000..42db6746b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,131 @@ +name: Bug report +description: Report a problem or regression in Bifrost +title: "[Bug]: " +labels: [bug] +assignees: [] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out a bug report! Please provide as much detail as possible. + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + options: + - label: I have searched existing issues and discussions to avoid duplicates + required: true + - label: I am using the latest version (or have tested against main/nightly) + required: false + + - type: textarea + id: description + attributes: + label: Description + description: What happened? Include screenshots if helpful. + placeholder: Clear and concise description of the bug + validations: + required: true + + - type: textarea + id: reproduction + attributes: + label: Steps to reproduce + description: Provide a minimal, reproducible example. Link to a repo, gist, or include exact steps. + placeholder: | + 1. Go to '...' + 2. Run '...' + 3. Observe '...' + validations: + required: true + + - type: input + id: expected + attributes: + label: Expected behavior + placeholder: What did you expect to happen? + validations: + required: true + + - type: input + id: actual + attributes: + label: Actual behavior + placeholder: What actually happened? + validations: + required: true + + - type: dropdown + id: area + attributes: + label: Affected area(s) + multiple: true + options: + - Core (Go) + - Framework + - Transports (HTTP) + - Plugins + - UI (Next.js) + - Docs + validations: + required: true + + - type: input + id: version + attributes: + label: Version + description: Affected version(s). + placeholder: e.g., v1.0.3 + validations: + required: true + + - type: textarea + id: env + attributes: + label: Environment + description: Include as many as apply. + placeholder: | + - OS: macOS 14.5, Linux x.y, Windows 11 + - Go: 1.22.x + - Node: 20.x, npm/pnpm/yarn version + - Browser (if UI): Chrome/Firefox/Safari versions + - Bifrost components and versions (core, transports, ui) + - Any relevant environment flags/config + render: text + validations: + required: false + + - type: textarea + id: logs + attributes: + label: Relevant logs/output + description: Paste error logs, stack traces, or console output. + render: shell + placeholder: | + + validations: + required: false + + - type: input + id: regression + attributes: + label: Regression? + description: If this worked in a previous version, which version? + placeholder: e.g., Worked in v0.8.0, broke in v0.9.0 + validations: + required: false + + - type: dropdown + id: severity + attributes: + label: Severity + options: + - Low (minor issue or cosmetic) + - Medium (some functionality impaired) + - High (major functionality broken) + - Critical (blocks releases or production) + validations: + required: true + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..99d680b0a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,2 @@ +blank_issues_enabled: false + diff --git a/.github/ISSUE_TEMPLATE/docs_issue.yml b/.github/ISSUE_TEMPLATE/docs_issue.yml new file mode 100644 index 000000000..ee3ce3dbb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/docs_issue.yml @@ -0,0 +1,45 @@ +name: Documentation issue +description: Report missing, unclear, or incorrect documentation +title: "[Docs]: " +labels: [documentation] +assignees: [] +body: + - type: markdown + attributes: + value: | + Help us improve the docs! Please provide links and suggestions. + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + options: + - label: I have searched existing issues and docs to avoid duplicates + required: true + + - type: input + id: page + attributes: + label: Affected page(s) + description: Provide the path or URL to the affected doc(s) + placeholder: docs/usage/providers.md or https://... + validations: + required: true + + - type: textarea + id: issue + attributes: + label: What’s wrong or missing? + description: Be as specific as possible. + validations: + required: true + + - type: textarea + id: suggestion + attributes: + label: Suggested change + description: Propose wording or structure improvements. + validations: + required: false + + diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 000000000..c138cf2a0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,69 @@ +name: Feature request +description: Suggest an idea or enhancement for Bifrost +title: "[Feature]: " +labels: [enhancement] +assignees: [] +body: + - type: markdown + attributes: + value: | + Thanks for proposing a feature! Please fill out the details below. + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + options: + - label: I have searched existing issues and discussions to avoid duplicates + required: true + + - type: textarea + id: problem + attributes: + label: Problem to solve + description: What problem does this feature solve? Who benefits? + placeholder: Describe the problem clearly. + validations: + required: true + + - type: textarea + id: proposal + attributes: + label: Proposed solution + description: Describe your proposed API/UX/CLI. Include examples if helpful. + placeholder: Provide details about how this should work. + validations: + required: true + + - type: textarea + id: alternatives + attributes: + label: Alternatives considered + description: What other solutions or workarounds did you consider? + validations: + required: false + + - type: dropdown + id: area + attributes: + label: Area(s) + multiple: true + options: + - Core (Go) + - Framework + - Transports (HTTP) + - Plugins + - UI (Next.js) + - Docs + validations: + required: true + + - type: textarea + id: additional + attributes: + label: Additional context + description: Add any other context, sketches, or references here. + validations: + required: false + + diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..0f339107f --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,72 @@ +## Summary + +Briefly explain the purpose of this PR and the problem it solves. + +## Changes + +- What was changed and why +- Any notable design decisions or trade-offs + +## Type of change + +- [ ] Bug fix +- [ ] Feature +- [ ] Refactor +- [ ] Documentation +- [ ] Chore/CI + +## Affected areas + +- [ ] Core (Go) +- [ ] Transports (HTTP) +- [ ] Providers/Integrations +- [ ] Plugins +- [ ] UI (Next.js) +- [ ] Docs + +## How to test + +Describe the steps to validate this change. Include commands and expected outcomes. + +```sh +# Core/Transports +go version +go test ./... + +# UI +cd ui +pnpm i || npm i +pnpm test || npm test +pnpm build || npm run build +``` + +If adding new configs or environment variables, document them here. + +## Screenshots/Recordings + +If UI changes, add before/after screenshots or short clips. + +## Breaking changes + +- [ ] Yes +- [ ] No + +If yes, describe impact and migration instructions. + +## Related issues + +Link related issues and discussions. Example: Closes #123 + +## Security considerations + +Note any security implications (auth, secrets, PII, sandboxing, etc.). + +## Checklist + +- [ ] I read `docs/contributing/README.md` and followed the guidelines +- [ ] I added/updated tests where appropriate +- [ ] I updated documentation where needed +- [ ] I verified builds succeed (Go and UI) +- [ ] I verified the CI pipeline passes locally if applicable + + diff --git a/.github/workflows/npx-publish.yml b/.github/workflows/npx-publish.yml new file mode 100644 index 000000000..820e6d67c --- /dev/null +++ b/.github/workflows/npx-publish.yml @@ -0,0 +1,106 @@ +name: NPX Package Publish + +# Triggers when npx package is tagged +on: + push: + tags: + - "npx/v*" + +# Prevent concurrent runs for the same trigger +concurrency: + group: npx-publish-${{ github.ref }} + cancel-in-progress: true + +jobs: + publish: + runs-on: ubuntu-latest + permissions: + contents: write + id-token: write # Required for npm provenance + steps: + # Checkout the repository + - name: Checkout repository + uses: actions/checkout@v4 + + # Set up Node.js environment + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + registry-url: "https://registry.npmjs.org" + cache: "npm" + cache-dependency-path: | + npx/package-lock.json + + # Extract and validate version from tag + - name: Extract version from tag + id: extract-version + run: ./.github/workflows/scripts/extract-npx-version.sh + + # Update package.json with the tagged version + - name: Update package version + working-directory: npx + run: | + VERSION="${{ steps.extract-version.outputs.version }}" + echo "πŸ“ Updating package.json version to $VERSION" + npm version "$VERSION" --no-git-tag-version + + # Install dependencies (if any) + - name: Install dependencies + working-directory: npx + run: npm ci + + # Run tests (if any exist) + - name: Run tests + working-directory: npx + run: | + if [ -f "package.json" ] && npm run | grep -q "test"; then + echo "πŸ§ͺ Running tests..." + npm test + else + echo "⏭️ No tests found, skipping..." + fi + + # Publish to npm + - name: Publish to npm + working-directory: npx + env: + NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} + run: | + VERSION="${{ steps.extract-version.outputs.version }}" + echo "πŸ“¦ Publishing @maximhq/bifrost@${VERSION} to npm..." + if npm view @maximhq/bifrost@"${VERSION}" version >/dev/null 2>&1; then + echo "ℹ️ @maximhq/bifrost@${VERSION} already exists on npm. Skipping publish." + else + npm publish --provenance --access public + fi + + # Create GitHub release + - name: Create GitHub Release + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + run: bash .github/workflows/scripts/create-npx-release.sh "${{ steps.extract-version.outputs.version }}" "${{ steps.extract-version.outputs.full-tag }}" + + # Discord notification + - name: Discord Notification + if: always() + env: + DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }} + run: | + AUTHOR="${{ github.actor }}" + COMMIT_AUTHOR="$(git log -1 --pretty=%an || true)" + if [ -n "$COMMIT_AUTHOR" ]; then AUTHOR="$COMMIT_AUTHOR"; fi + if [ "${{ job.status }}" = "success" ]; then + TITLE="πŸ“¦ **NPX Package Published**" + STATUS="βœ… Success" + VERSION_LINE="**Version**: \`${{ steps.extract-version.outputs.version }}\`" + PACKAGE_LINE="**Package**: \`@maximhq/bifrost\`" + NPM_LINK="**[View on npm](https://www.npmjs.com/package/@maximhq/bifrost)**" + MESSAGE="$TITLE\n**Status**: $STATUS\n$VERSION_LINE\n$PACKAGE_LINE\n$NPM_LINK\n**Tag**: \`${{ steps.extract-version.outputs.full-tag }}\`\n**Commit**: \`${{ github.sha }}\`\n**Author**: ${AUTHOR}\n**[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})**" + else + TITLE="πŸ“¦ **NPX Package Publish Failed**" + STATUS="❌ Failed" + MESSAGE="$TITLE\n**Status**: $STATUS\n**Tag**: \`${{ steps.extract-version.outputs.full-tag }}\`\n**Commit**: \`${{ github.sha }}\`\n**Author**: ${AUTHOR}\n**[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})**" + fi + payload="$(jq -n --arg content "$MESSAGE" '{content:$content}')" + curl -sS -H "Content-Type: application/json" -d "$payload" "$DISCORD_WEBHOOK" diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml new file mode 100644 index 000000000..01d16b51a --- /dev/null +++ b/.github/workflows/release-pipeline.yml @@ -0,0 +1,439 @@ +name: Release Pipeline + +# Triggers automatically on push to main when any version file changes +on: + push: + branches: ["main"] + +# Prevent concurrent runs +concurrency: + group: release-pipeline + cancel-in-progress: false + +jobs: + # Check if pipeline should be skipped based on first line of commit message + check-skip: + runs-on: ubuntu-latest + outputs: + should-skip: ${{ steps.check.outputs.should-skip }} + steps: + - name: Check if pipeline should be skipped + id: check + env: + COMMIT_MESSAGE: ${{ github.event.head_commit.message }} + run: | + FIRST_LINE=$(echo "$COMMIT_MESSAGE" | head -n 1) + if [[ "$FIRST_LINE" == *"--skip-pipeline"* ]]; then + echo "should-skip=true" >> $GITHUB_OUTPUT + else + echo "should-skip=false" >> $GITHUB_OUTPUT + fi + + # Detect what needs to be released + detect-changes: + needs: [check-skip] + runs-on: ubuntu-latest + # Skip if first line of commit message contains --skip-pipeline + if: needs.check-skip.outputs.should-skip != 'true' + outputs: + core-needs-release: ${{ steps.detect.outputs.core-needs-release }} + framework-needs-release: ${{ steps.detect.outputs.framework-needs-release }} + plugins-need-release: ${{ steps.detect.outputs.plugins-need-release }} + bifrost-http-needs-release: ${{ steps.detect.outputs.bifrost-http-needs-release }} + docker-needs-release: ${{ steps.detect.outputs.docker-needs-release }} + changed-plugins: ${{ steps.detect.outputs.changed-plugins }} + core-version: ${{ steps.detect.outputs.core-version }} + framework-version: ${{ steps.detect.outputs.framework-version }} + transport-version: ${{ steps.detect.outputs.transport-version }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + + - name: Install jq + run: | + sudo apt-get update + sudo apt-get install -y jq + + - name: Detect what needs release + id: detect + run: ./.github/workflows/scripts/detect-all-changes.sh "auto" + + core-release: + needs: [check-skip, detect-changes] + if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.core-needs-release == 'true' + runs-on: ubuntu-latest + permissions: + contents: write + outputs: + success: ${{ steps.release.outputs.success }} + version: ${{ needs.detect-changes.outputs.core-version }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.1" + + - name: Configure Git + run: | + git config user.name "GitHub Actions Bot" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Release core + id: release + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + run: ./.github/workflows/scripts/release-core.sh "${{ needs.detect-changes.outputs.core-version }}" + + framework-release: + needs: [check-skip, detect-changes, core-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.framework-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped')" + runs-on: ubuntu-latest + permissions: + contents: write + outputs: + success: ${{ steps.release.outputs.success }} + version: ${{ needs.detect-changes.outputs.framework-version }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.1" + + - name: Configure Git + run: | + git config user.name "GitHub Actions Bot" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Set up Docker Compose + run: | + # Verify Docker is available + docker --version + # Install Docker Compose if not available as plugin + if ! docker compose version >/dev/null 2>&1; then + echo "Installing Docker Compose..." + sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose + sudo chmod +x /usr/local/bin/docker-compose + docker-compose --version + else + echo "Docker Compose plugin is available" + docker compose version + fi + + - name: Release framework + id: release + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + run: ./.github/workflows/scripts/release-framework.sh "${{ needs.detect-changes.outputs.framework-version }}" + + plugins-release: + needs: [check-skip, detect-changes, core-release, framework-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.plugins-need-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped')" + runs-on: ubuntu-latest + permissions: + contents: write + outputs: + success: ${{ steps.release.outputs.success }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Install jq + run: | + sudo apt-get update + sudo apt-get install -y jq + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.1" + + - name: Configure Git + run: | + git config user.name "GitHub Actions Bot" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Set up Docker Compose + run: | + # Verify Docker is available + docker --version + # Install Docker Compose if not available as plugin + if ! docker compose version >/dev/null 2>&1; then + echo "Installing Docker Compose..." + sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose + sudo chmod +x /usr/local/bin/docker-compose + docker-compose --version + else + echo "Docker Compose plugin is available" + docker compose version + fi + + - name: Release all changed plugins + id: release + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + MAXIM_API_KEY: ${{ secrets.MAXIM_API_KEY }} + MAXIM_LOGGER_ID: ${{ secrets.MAXIM_LOG_REPO_ID }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: ./.github/workflows/scripts/release-all-plugins.sh '${{ needs.detect-changes.outputs.changed-plugins }}' + + bifrost-http-release: + needs: [check-skip, detect-changes, core-release, framework-release, plugins-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped')" + runs-on: ubuntu-latest + permissions: + contents: write + outputs: + success: ${{ steps.release.outputs.success }} + version: ${{ needs.detect-changes.outputs.transport-version }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.1" + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Configure Git + run: | + git config user.name "GitHub Actions Bot" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Set up Docker Compose + run: | + # Verify Docker is available + docker --version + # Install Docker Compose if not available as plugin + if ! docker compose version >/dev/null 2>&1; then + echo "Installing Docker Compose..." + sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose + sudo chmod +x /usr/local/bin/docker-compose + docker-compose --version + else + echo "Docker Compose plugin is available" + docker compose version + fi + + - name: Release bifrost-http + id: release + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + R2_ENDPOINT: ${{ secrets.R2_ENDPOINT }} + R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} + R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} + R2_BUCKET: ${{ secrets.R2_BUCKET }} + run: ./.github/workflows/scripts/release-bifrost-http.sh "${{ needs.detect-changes.outputs.transport-version }}" + + # Docker build amd64 + docker-build-amd64: + needs: [check-skip, detect-changes, bifrost-http-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')" + runs-on: ubuntu-latest + permissions: + contents: write + env: + REGISTRY: docker.io + ACCOUNT: maximhq + IMAGE_NAME: bifrost + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Determine Docker tags + id: tags + run: | + git pull origin ${{ github.ref_name }} + VERSION="${{ needs.detect-changes.outputs.transport-version }}" + BASE_TAG="${{ env.REGISTRY }}/${{ env.ACCOUNT }}/${{ env.IMAGE_NAME }}:v${VERSION}-amd64" + echo "tags=${BASE_TAG}" >> $GITHUB_OUTPUT + + - name: Build and push AMD64 Docker image + uses: docker/build-push-action@v5 + with: + context: . + build-args: | + VERSION=${{ needs.detect-changes.outputs.transport-version }} + file: ./transports/Dockerfile + push: true + tags: ${{ steps.tags.outputs.tags }} + platforms: linux/amd64 + cache-from: type=gha + cache-to: type=gha,mode=max + + # Docker build arm64 + docker-build-arm64: + needs: [check-skip, detect-changes, bifrost-http-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')" + runs-on: ubuntu-24.04-arm + permissions: + contents: write + env: + REGISTRY: docker.io + ACCOUNT: maximhq + IMAGE_NAME: bifrost + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Determine Docker tags + id: tags + run: | + git pull origin ${{ github.ref_name }} + VERSION="${{ needs.detect-changes.outputs.transport-version }}" + BASE_TAG="${{ env.REGISTRY }}/${{ env.ACCOUNT }}/${{ env.IMAGE_NAME }}:v${VERSION}-arm64" + echo "tags=${BASE_TAG}" >> $GITHUB_OUTPUT + + - name: Build and push ARM64 Docker image + uses: docker/build-push-action@v5 + with: + context: . + file: ./transports/Dockerfile + push: true + build-args: | + VERSION=${{ needs.detect-changes.outputs.transport-version }} + tags: ${{ steps.tags.outputs.tags }} + platforms: linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max + + # Docker manifest + docker-manifest: + needs: [check-skip, detect-changes, docker-build-amd64, docker-build-arm64] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.docker-build-amd64.result == 'success' && needs.docker-build-arm64.result == 'success'" + runs-on: ubuntu-latest + env: + REGISTRY: docker.io + ACCOUNT: maximhq + IMAGE_NAME: bifrost + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Create and push multi-arch manifest + run: | + ./.github/workflows/scripts/create-docker-manifest.sh "${{ needs.detect-changes.outputs.transport-version }}" + + # Push Mintlify changelog + push-mintlify-changelog: + needs: [check-skip, detect-changes, bifrost-http-release] + if: "always() && needs.check-skip.outputs.should-skip != 'true' && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')" + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Push Mintlify changelog + run: | + ./.github/workflows/scripts/push-mintlify-changelog.sh "${{ needs.detect-changes.outputs.transport-version }}" + + # Notification + notify: + needs: [check-skip, detect-changes, core-release, framework-release, plugins-release, bifrost-http-release, docker-manifest] + if: "always() && needs.check-skip.outputs.should-skip != 'true'" + runs-on: ubuntu-latest + steps: + - name: Install jq + run: | + sudo apt-get update + sudo apt-get install -y jq + + - name: Discord Notification + env: + DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }} + run: | + # Build status summary + CORE_STATUS="⏭️ Skipped" + FRAMEWORK_STATUS="⏭️ Skipped" + PLUGINS_STATUS="⏭️ Skipped" + BIFROST_STATUS="⏭️ Skipped" + + if [ "${{ needs.core-release.result }}" = "success" ]; then + CORE_STATUS="βœ… Released v${{ needs.detect-changes.outputs.core-version }}" + elif [ "${{ needs.core-release.result }}" = "failure" ]; then + CORE_STATUS="❌ Failed" + fi + + if [ "${{ needs.framework-release.result }}" = "success" ]; then + FRAMEWORK_STATUS="βœ… Released v${{ needs.detect-changes.outputs.framework-version }}" + elif [ "${{ needs.framework-release.result }}" = "failure" ]; then + FRAMEWORK_STATUS="❌ Failed" + fi + + if [ "${{ needs.plugins-release.result }}" = "success" ]; then + PLUGINS_STATUS="βœ… Released plugins" + elif [ "${{ needs.plugins-release.result }}" = "failure" ]; then + PLUGINS_STATUS="❌ Failed" + fi + + if [ "${{ needs.bifrost-http-release.result }}" = "success" ]; then + BIFROST_STATUS="βœ… Released v${{ needs.detect-changes.outputs.transport-version }}" + elif [ "${{ needs.bifrost-http-release.result }}" = "failure" ]; then + BIFROST_STATUS="❌ Failed" + fi + + # Build the message with proper formatting + MESSAGE=$(printf "πŸš€ **Release Pipeline Complete**\n\n**Components:**\nβ€’ Core: %s\nβ€’ Framework: %s\nβ€’ Plugins: %s\nβ€’ Bifrost HTTP: %s\n\n**Details:**\nβ€’ Branch: \`main\`\nβ€’ Commit: \`%.8s\`\nβ€’ Author: %s\n\n[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" "$CORE_STATUS" "$FRAMEWORK_STATUS" "$PLUGINS_STATUS" "$BIFROST_STATUS" "${{ github.sha }}" "${{ github.actor }}") + + payload="$(jq -n --arg content "$MESSAGE" '{content:$content}')" + curl -sS -H "Content-Type: application/json" -d "$payload" "$DISCORD_WEBHOOK" diff --git a/.github/workflows/scripts/build-executables.sh b/.github/workflows/scripts/build-executables.sh new file mode 100755 index 000000000..8b3d12b36 --- /dev/null +++ b/.github/workflows/scripts/build-executables.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Cross-compile Go binaries for multiple platforms +# Usage: ./build-executables.sh + +# Require version argument (matches usage) +if [[ -z "${1:-}" ]]; then + echo "Usage: $0 " >&2 + exit 1 +fi +VERSION="$1" + +echo "πŸ”¨ Building Go executables with version: $VERSION" + +# Get the script directory and project root +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# Clean and create dist directory +rm -rf "$PROJECT_ROOT/dist" +mkdir -p "$PROJECT_ROOT/dist" + + +# Define platforms +platforms=( + "darwin/amd64" + "darwin/arm64" + "linux/amd64" + "linux/arm64" + "windows/amd64" +) + +MODULE_PATH="$PROJECT_ROOT/transports/bifrost-http" + + +for platform in "${platforms[@]}"; do + IFS='/' read -r PLATFORM_DIR GOARCH <<< "$platform" + + case "$PLATFORM_DIR" in + "windows") GOOS="windows" ;; + "darwin") GOOS="darwin" ;; + "linux") GOOS="linux" ;; + *) echo "Unsupported platform: $PLATFORM_DIR"; exit 1 ;; + esac + + output_name="bifrost-http" + [[ "$GOOS" = "windows" ]] && output_name+='.exe' + + echo "Building bifrost-http for $PLATFORM_DIR/$GOARCH..." + mkdir -p "$PROJECT_ROOT/dist/$PLATFORM_DIR/$GOARCH" + + # Change to the module directory for building + cd "$MODULE_PATH" + + if [[ "$GOOS" = "linux" ]]; then + if [[ "$GOARCH" = "amd64" ]]; then + CC_COMPILER="x86_64-linux-musl-gcc" + CXX_COMPILER="x86_64-linux-musl-g++" + elif [[ "$GOARCH" = "arm64" ]]; then + CC_COMPILER="aarch64-linux-musl-gcc" + CXX_COMPILER="aarch64-linux-musl-g++" + fi + + env GOWORK=off CGO_ENABLED=1 GOOS="$GOOS" GOARCH="$GOARCH" CC="$CC_COMPILER" CXX="$CXX_COMPILER" \ + go build -trimpath -tags "netgo,osusergo,sqlite_static" \ + -ldflags "-s -w -buildid= -extldflags '-static' -X main.Version=v${VERSION}" \ + -o "$PROJECT_ROOT/dist/$PLATFORM_DIR/$GOARCH/$output_name" . + + elif [[ "$GOOS" = "windows" ]]; then + if [[ "$GOARCH" = "amd64" ]]; then + CC_COMPILER="x86_64-w64-mingw32-gcc" + CXX_COMPILER="x86_64-w64-mingw32-g++" + fi + + env GOWORK=off CGO_ENABLED=1 GOOS="$GOOS" GOARCH="$GOARCH" CC="$CC_COMPILER" CXX="$CXX_COMPILER" \ + go build -trimpath -ldflags "-s -w -buildid= -X main.Version=v${VERSION}" \ + -o "$PROJECT_ROOT/dist/$PLATFORM_DIR/$GOARCH/$output_name" . + + else # Darwin (macOS) + if [[ "$GOARCH" = "amd64" ]]; then + CC_COMPILER="o64-clang" + CXX_COMPILER="o64-clang++" + elif [[ "$GOARCH" = "arm64" ]]; then + CC_COMPILER="oa64-clang" + CXX_COMPILER="oa64-clang++" + fi + + env GOWORK=off CGO_ENABLED=1 GOOS="$GOOS" GOARCH="$GOARCH" CC="$CC_COMPILER" CXX="$CXX_COMPILER" \ + go build -trimpath -ldflags "-s -w -buildid= -X main.Version=v${VERSION}" \ + -o "$PROJECT_ROOT/dist/$PLATFORM_DIR/$GOARCH/$output_name" . + fi + + # Change back to project root + cd "$PROJECT_ROOT" +done + +echo "βœ… All binaries built successfully" diff --git a/.github/workflows/scripts/changelog-utils.sh b/.github/workflows/scripts/changelog-utils.sh new file mode 100644 index 000000000..0fa7f9011 --- /dev/null +++ b/.github/workflows/scripts/changelog-utils.sh @@ -0,0 +1,13 @@ +# Function to extract changelog content from a file +# Usage: get_changelog_content +get_changelog_content() { + CHANGELOG_BODY=$(cat $1) + # Skip comments from changelog + CHANGELOG_BODY=$(echo "$CHANGELOG_BODY" | grep -v '^') + # If changelog is empty, return error + if [ -z "$CHANGELOG_BODY" ]; then + echo "❌ Changelog is empty" + exit 1 + fi + echo "$CHANGELOG_BODY" +} diff --git a/.github/workflows/scripts/check-core-version-increment.sh b/.github/workflows/scripts/check-core-version-increment.sh new file mode 100755 index 000000000..492c4901d --- /dev/null +++ b/.github/workflows/scripts/check-core-version-increment.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Check if core version has been incremented and needs release +# Usage: ./check-core-version-increment.sh + +CURRENT_VERSION=$(cat core/version) +TAG_NAME="core/v${CURRENT_VERSION}" + +echo "πŸ“‹ Current core version: $CURRENT_VERSION" +echo "🏷️ Expected tag: $TAG_NAME" + +# Check if tag already exists +if git rev-parse --verify "$TAG_NAME" >/dev/null 2>&1; then + echo "⚠️ Tag $TAG_NAME already exists" + { + echo "should-release=false" + echo "new-version=$CURRENT_VERSION" + echo "tag-exists=true" + } >> "$GITHUB_OUTPUT" + exit 0 +fi + +# Get previous version from git tags +LATEST_CORE_TAG=$(git tag -l "core/v*" | sort -V | tail -1) + +if [ -z "$LATEST_CORE_TAG" ]; then + echo "πŸ“¦ No existing core tags found, this will be the first release" + { + echo "should-release=true" + echo "new-version=$CURRENT_VERSION" + echo "tag-exists=false" + } >> "$GITHUB_OUTPUT" + exit 0 +fi + +PREVIOUS_VERSION=${LATEST_CORE_TAG#core/v} +echo "πŸ“‹ Previous core version: $PREVIOUS_VERSION" + +# Compare versions using sort -V (version sort) +if [ "$(printf '%s\n' "$PREVIOUS_VERSION" "$CURRENT_VERSION" | sort -V | tail -1)" = "$CURRENT_VERSION" ] && [ "$PREVIOUS_VERSION" != "$CURRENT_VERSION" ]; then + echo "βœ… Version incremented from $PREVIOUS_VERSION to $CURRENT_VERSION" + echo "πŸš€ Core release needed" + { + echo "should-release=true" + echo "new-version=$CURRENT_VERSION" + echo "tag-exists=false" + } >> "$GITHUB_OUTPUT" +else + echo "⏭️ No version increment detected (current: $CURRENT_VERSION, latest: $PREVIOUS_VERSION)" + { + echo "should-release=false" + echo "new-version=$CURRENT_VERSION" + echo "tag-exists=false" + } >> "$GITHUB_OUTPUT" +fi diff --git a/.github/workflows/scripts/check-dependency-flow.sh b/.github/workflows/scripts/check-dependency-flow.sh new file mode 100755 index 000000000..57c34a85b --- /dev/null +++ b/.github/workflows/scripts/check-dependency-flow.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Check the dependency flow and suggest next steps +# Usage: ./check-dependency-flow.sh [version] +# stage: core|framework|plugins +# version: required for core/framework; optional for plugins +usage() { + echo "Usage: $0 [version]" >&2 + echo "Examples:" >&2 + echo " $0 core v1.2.3" >&2 + echo " $0 framework v1.2.3" >&2 + echo " $0 plugins" >&2 +} +if [[ $# -lt 1 ]]; then + usage + exit 2 +fi +STAGE="${1:-}" +VERSION="${2:-}" + +# Validate stage first, then enforce version requirement by stage +case "$STAGE" in + core|framework|plugins) + ;; + *) + echo "❌ Unknown stage: $STAGE" >&2 + usage + exit 1 + ;; +esac + +# VERSION is required for core/framework; optional for plugins +if [[ "$STAGE" != "plugins" && -z "${VERSION:-}" ]]; then + echo "❌ VERSION is required for stage '$STAGE'." >&2 + usage + exit 2 +fi + +case "$STAGE" in + "core") + echo "πŸ”§ Core v$VERSION released!" + echo "" + echo "πŸ“‹ Dependency Flow Status:" + echo "βœ… Core: v$VERSION (just released)" + echo "❓ Framework: Check if update needed" + echo "❓ Plugins: Will check after framework" + echo "❓ Bifrost HTTP: Will check after plugins" + echo "" + echo "πŸ”„ Next Step: Manually trigger Framework Release if needed" + ;; + + "framework") + echo "πŸ“¦ Framework v$VERSION released!" + echo "" + echo "πŸ“‹ Dependency Flow Status:" + echo "βœ… Core: (already updated)" + echo "βœ… Framework: v$VERSION (just released)" + echo "❓ Plugins: Check if any need updates" + echo "❓ Bifrost HTTP: Will check after plugins" + echo "" + echo "πŸ”„ Next Step: Check Plugins Release workflow" + ;; + + "plugins") + echo "πŸ”Œ Plugins ${VERSION:+v$VERSION }released!" + echo "" + echo "πŸ“‹ Dependency Flow Status:" + echo "βœ… Core: (already updated)" + echo "βœ… Framework: (already updated)" + echo "βœ… Plugins: (just released)" + echo "❓ Bifrost HTTP: Check if update needed" + echo "" + echo "πŸ”„ Next Step: Manually trigger Bifrost HTTP Release if needed" + ;; + + *) + echo "❌ Unknown stage: $STAGE" + exit 1 + ;; +esac diff --git a/.github/workflows/scripts/configure-r2.sh b/.github/workflows/scripts/configure-r2.sh new file mode 100755 index 000000000..36085e624 --- /dev/null +++ b/.github/workflows/scripts/configure-r2.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Configure AWS CLI for R2 uploads +# Usage: ./configure-r2.sh + +echo "βš™οΈ Configuring AWS CLI for R2..." + +pip install awscli + +# Clean and trim environment variables (removing any whitespace) +R2_ENDPOINT="$(echo "$R2_ENDPOINT" | tr -d '[:space:]')" +R2_ACCESS_KEY_ID="$(echo "$R2_ACCESS_KEY_ID" | tr -d '[:space:]')" +R2_SECRET_ACCESS_KEY="$(echo "$R2_SECRET_ACCESS_KEY" | tr -d '[:space:]')" + +# Validate environment variables +if [ -z "$R2_ENDPOINT" ] || [ -z "$R2_ACCESS_KEY_ID" ] || [ -z "$R2_SECRET_ACCESS_KEY" ]; then + echo "❌ Missing required R2 credentials" + exit 1 +fi + +# Configure AWS CLI for R2 using dedicated profile +aws configure set --profile R2 aws_access_key_id "$R2_ACCESS_KEY_ID" +aws configure set --profile R2 aws_secret_access_key "$R2_SECRET_ACCESS_KEY" +aws configure set --profile R2 region us-east-1 +aws configure set --profile R2 s3.signature_version s3v4 + +# Test connection +echo "πŸ” Testing R2 connection..." +aws s3 ls s3://prod-downloads/ --endpoint-url "$R2_ENDPOINT" --profile R2 >/dev/null +echo "βœ… R2 connection successful" diff --git a/.github/workflows/scripts/create-docker-manifest.sh b/.github/workflows/scripts/create-docker-manifest.sh new file mode 100755 index 000000000..a594507fd --- /dev/null +++ b/.github/workflows/scripts/create-docker-manifest.sh @@ -0,0 +1,36 @@ +# Validate input argument +if [ "${1:-}" = "" ]; then + echo "Usage: $0 " >&2 + exit 1 +fi + +VERSION="$1" +REGISTRY="docker.io" +ACCOUNT="maximhq" +IMAGE_NAME="bifrost" +IMAGE="${REGISTRY}/${ACCOUNT}/${IMAGE_NAME}" + +# Get the actual image digests from the platform-specific builds +AMD64_DIGEST=$(docker manifest inspect ${IMAGE}:v${VERSION}-amd64 | jq -r '.manifests[0].digest') +ARM64_DIGEST=$(docker manifest inspect ${IMAGE}:v${VERSION}-arm64 | jq -r '.manifests[0].digest') + +echo "AMD64 digest: ${AMD64_DIGEST}" +echo "ARM64 digest: ${ARM64_DIGEST}" + +# Create manifest for versioned tag using digests +docker manifest create \ + ${IMAGE}:v${VERSION} \ + ${IMAGE}@${AMD64_DIGEST} \ + ${IMAGE}@${ARM64_DIGEST} + +docker manifest push ${IMAGE}:v${VERSION} + +# Create latest manifest only for stable versions +if [[ "$VERSION" != *-* ]]; then + docker manifest create \ + ${IMAGE}:latest \ + ${IMAGE}@${AMD64_DIGEST} \ + ${IMAGE}@${ARM64_DIGEST} + + docker manifest push ${IMAGE}:latest +fi \ No newline at end of file diff --git a/.github/workflows/scripts/create-npx-release.sh b/.github/workflows/scripts/create-npx-release.sh new file mode 100755 index 000000000..db33d5ed8 --- /dev/null +++ b/.github/workflows/scripts/create-npx-release.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Create GitHub release for NPX package +# Usage: ./create-npx-release.sh + +VERSION="$1" +FULL_TAG="$2" + +if [[ -z "$VERSION" || -z "$FULL_TAG" ]]; then + echo "❌ Usage: $0 " + exit 1 +fi +# Mark prereleases when version contains a hyphen +PRERELEASE_FLAG="" +if [[ "$VERSION" == *-* ]]; then + PRERELEASE_FLAG="--prerelease" +fi +TITLE="NPX Package v$VERSION" + +# Create release body +BODY="## NPX Package Release + +### πŸ“¦ NPX Package v$VERSION + +The Bifrost CLI is now available on npm! + +### Installation + +\`\`\`bash +# Install globally +npm install -g @maximhq/bifrost + +# Or use with npx (no installation needed) +npx @maximhq/bifrost --help +\`\`\` + +### Usage + +\`\`\`bash +# Start Bifrost HTTP server +bifrost + +# Use specific transport version +bifrost --transport-version v1.2.3 + +# Get help +bifrost --help +\`\`\` + +### Links + +- πŸ“¦ [View on npm](https://www.npmjs.com/package/@maximhq/bifrost) +- πŸ“š [Documentation](https://github.com/maximhq/bifrost) +- πŸ› [Report Issues](https://github.com/maximhq/bifrost/issues) + +### What's New + +This NPX package provides a convenient way to run Bifrost without manual binary downloads. The CLI automatically: + +- Detects your platform and architecture +- Downloads the appropriate binary +- Supports version pinning with \`--transport-version\` +- Provides progress indicators for downloads + +--- +_This release was automatically created from tag \`$FULL_TAG\`_" + +# Create release +echo "πŸŽ‰ Creating GitHub release for $TITLE..." +if gh release view "$FULL_TAG" >/dev/null 2>&1; then + echo "ℹ️ Release $FULL_TAG already exists. Skipping creation." + exit 0 +fi +gh release create "$FULL_TAG" \ + --title "$TITLE" \ + --notes "$BODY" \ + --latest=false \ + --verify-tag \ + ${PRERELEASE_FLAG} diff --git a/.github/workflows/scripts/detect-all-changes.sh b/.github/workflows/scripts/detect-all-changes.sh new file mode 100755 index 000000000..3972bd092 --- /dev/null +++ b/.github/workflows/scripts/detect-all-changes.sh @@ -0,0 +1,312 @@ +#!/usr/bin/env bash +set -euo pipefail +shopt -s nullglob + +# Detect what components need to be released based on version changes +# Usage: ./detect-all-changes.sh +echo "πŸ” Auto-detecting version changes across all components..." + +# Initialize outputs +CORE_NEEDS_RELEASE="false" +FRAMEWORK_NEEDS_RELEASE="false" +PLUGINS_NEED_RELEASE="false" +BIFROST_HTTP_NEEDS_RELEASE="false" +DOCKER_NEEDS_RELEASE="false" +CHANGED_PLUGINS="[]" + +# Get current versions +CORE_VERSION=$(cat core/version) +FRAMEWORK_VERSION=$(cat framework/version) +TRANSPORT_VERSION=$(cat transports/version) + +echo "πŸ“¦ Current versions:" +echo " Core: $CORE_VERSION" +echo " Framework: $FRAMEWORK_VERSION" +echo " Transport: $TRANSPORT_VERSION" + +START_FROM="none" + +# Check Core +echo "" +echo "πŸ”§ Checking core..." +CORE_TAG="core/v${CORE_VERSION}" +if git rev-parse --verify "$CORE_TAG" >/dev/null 2>&1; then + echo " ⏭️ Tag $CORE_TAG already exists" +else + # Get previous version + LATEST_CORE_TAG=$(git tag -l "core/v*" | sort -V | tail -1) + echo "🏷️ Latest core tag $LATEST_CORE_TAG" + if [ -z "$LATEST_CORE_TAG" ]; then + echo " βœ… First core release: $CORE_VERSION" + CORE_NEEDS_RELEASE="true" + else + if [[ "$CORE_VERSION" == *"-"* ]]; then + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "core/v*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_CORE_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest core tag (stable preferred): $LATEST_CORE_TAG" + else + # No stable versions, get highest prerelease + LATEST_CORE_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest core tag (prerelease only): $LATEST_CORE_TAG" + fi + else + # VERSION has no prerelease, so only consider stable releases + LATEST_CORE_TAG=$(git tag -l "core/v*" | grep -v '\-' | sort -V | tail -1) + echo "latest core tag (stable only): $LATEST_CORE_TAG" + fi + PREVIOUS_CORE_VERSION=${LATEST_CORE_TAG#core/v} + echo " πŸ“‹ Previous: $PREVIOUS_CORE_VERSION, Current: $CORE_VERSION" + # Fixed: Use head -1 instead of tail -1 for your sort -V behavior, and check against current version + if [ "$(printf '%s\n' "$PREVIOUS_CORE_VERSION" "$CORE_VERSION" | sort -V | tail -1)" = "$CORE_VERSION" ] && [ "$PREVIOUS_CORE_VERSION" != "$CORE_VERSION" ]; then + echo " βœ… Core version incremented: $PREVIOUS_CORE_VERSION β†’ $CORE_VERSION" + CORE_NEEDS_RELEASE="true" + else + echo " ⏭️ No core version increment" + fi + fi +fi + +# Check Framework +echo "" +echo "πŸ“¦ Checking framework..." +FRAMEWORK_TAG="framework/v${FRAMEWORK_VERSION}" +if git rev-parse --verify "$FRAMEWORK_TAG" >/dev/null 2>&1; then + echo " ⏭️ Tag $FRAMEWORK_TAG already exists" +else + ALL_TAGS=$(git tag -l "framework/v*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + LATEST_FRAMEWORK_TAG="" + if [ -n "$STABLE_TAGS" ]; then + LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG" + else + LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG" + fi + if [ -z "$LATEST_FRAMEWORK_TAG" ]; then + echo " βœ… First framework release: $FRAMEWORK_VERSION" + FRAMEWORK_NEEDS_RELEASE="true" + else + PREVIOUS_FRAMEWORK_VERSION=${LATEST_FRAMEWORK_TAG#framework/v} + echo " πŸ“‹ Previous: $PREVIOUS_FRAMEWORK_VERSION, Current: $FRAMEWORK_VERSION" + # Fixed: Use head -1 instead of tail -1 for your sort -V behavior, and check against current version + if [ "$(printf '%s\n' "$PREVIOUS_FRAMEWORK_VERSION" "$FRAMEWORK_VERSION" | sort -V | tail -1)" = "$FRAMEWORK_VERSION" ] && [ "$PREVIOUS_FRAMEWORK_VERSION" != "$FRAMEWORK_VERSION" ]; then + echo " βœ… Framework version incremented: $PREVIOUS_FRAMEWORK_VERSION β†’ $FRAMEWORK_VERSION" + FRAMEWORK_NEEDS_RELEASE="true" + else + echo " ⏭️ No framework version increment" + fi + fi +fi + +# Check Plugins +echo "" +echo "πŸ”Œ Checking plugins..." +PLUGIN_CHANGES=() + +for plugin_dir in plugins/*/; do + if [ ! -d "$plugin_dir" ]; then + continue + fi + + plugin_name=$(basename "$plugin_dir") + version_file="${plugin_dir}version" + + if [ ! -f "$version_file" ]; then + echo " ⚠️ No version file for: $plugin_name" + continue + fi + + current_version=$(cat "$version_file" | tr -d '\n\r') + if [ -z "$current_version" ]; then + echo " ⚠️ Empty version file for: $plugin_name" + continue + fi + + tag_name="plugins/${plugin_name}/v${current_version}" + echo " πŸ“¦ Plugin: $plugin_name (v$current_version)" + + if git rev-parse --verify "$tag_name" >/dev/null 2>&1; then + echo " ⏭️ Tag already exists" + continue + fi + + if [[ "$current_version" == *"-"* ]]; then + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-') + + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" + else + # No stable versions, get highest prerelease + LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" + fi + else + # VERSION has no prerelease, so only consider stable releases + LATEST_PLUGIN_TAG=$(git tag -l "plugins/${plugin_name}/v*" | grep -v '\-' | sort -V | tail -1) + echo "latest plugin tag (stable only): $LATEST_PLUGIN_TAG" + fi + + latest_tag=$LATEST_PLUGIN_TAG + if [ -z "$latest_tag" ]; then + echo " βœ… First release" + PLUGIN_CHANGES+=("$plugin_name") + else + previous_version=${latest_tag#plugins/${plugin_name}/v} + echo "previous version: $previous_version" + echo "current version: $current_version" + echo "latest tag: $latest_tag" + if [ "$(printf '%s\n' "$previous_version" "$current_version" | sort -V | tail -1)" = "$current_version" ] && [ "$previous_version" != "$current_version" ]; then + echo " βœ… Version incremented: $previous_version β†’ $current_version" + PLUGIN_CHANGES+=("$plugin_name") + else + echo " ⏭️ No version increment" + fi + fi +done + +if [ ${#PLUGIN_CHANGES[@]} -gt 0 ]; then + PLUGINS_NEED_RELEASE="true" + echo " πŸ”„ Plugins with changes: ${PLUGIN_CHANGES[*]}" +else + echo " ⏭️ No plugin changes detected" +fi + +# Check Bifrost HTTP +echo "" +echo "πŸš€ Checking bifrost-http..." +TRANSPORT_TAG="transports/v${TRANSPORT_VERSION}" +DOCKER_TAG_EXISTS="false" + +# Check if Git tag exists +GIT_TAG_EXISTS="false" +if git rev-parse --verify "$TRANSPORT_TAG" >/dev/null 2>&1; then + echo " ⏭️ Git tag $TRANSPORT_TAG already exists" + GIT_TAG_EXISTS="true" +fi + +# Check if Docker tag exists on DockerHub +echo " 🐳 Checking DockerHub for tag v${TRANSPORT_VERSION}..." +DOCKER_CHECK_RESPONSE=$(curl -s "https://registry.hub.docker.com/v2/repositories/maximhq/bifrost/tags/v${TRANSPORT_VERSION}/" 2>/dev/null || echo "") +if [ -n "$DOCKER_CHECK_RESPONSE" ] && echo "$DOCKER_CHECK_RESPONSE" | grep -q '"name"'; then + echo " ⏭️ Docker tag v${TRANSPORT_VERSION} already exists on DockerHub" + DOCKER_TAG_EXISTS="true" +else + echo " ❌ Docker tag v${TRANSPORT_VERSION} not found on DockerHub" +fi + +# Determine if release is needed +if [ "$GIT_TAG_EXISTS" = "true" ] && [ "$DOCKER_TAG_EXISTS" = "true" ]; then + echo " ⏭️ Both Git tag and Docker image exist - no release needed" +else + # Get all transport tags, prioritize stable over prerelease for same base version + ALL_TRANSPORT_TAGS=$(git tag -l "transports/v*" | sort -V) + + # Function to get base version (remove prerelease suffix) + get_base_version() { + echo "$1" | sed 's/-.*$//' + } + + # Find the latest version, prioritizing stable over prerelease + LATEST_TRANSPORT_TAG="" + LATEST_BASE_VERSION="" + + for tag in $ALL_TRANSPORT_TAGS; do + version=${tag#transports/v} + base_version=$(get_base_version "$version") + + # If this base version is newer, or same base version but current is stable and we had prerelease + if [ -z "$LATEST_BASE_VERSION" ] || \ + [ "$(printf '%s\n' "$LATEST_BASE_VERSION" "$base_version" | sort -V | tail -1)" = "$base_version" ]; then + + if [ "$base_version" = "$LATEST_BASE_VERSION" ]; then + # Same base version - prefer stable (no hyphen) over prerelease + if [[ "$version" != *"-"* ]] && [[ "${LATEST_TRANSPORT_TAG#transports/v}" == *"-"* ]]; then + LATEST_TRANSPORT_TAG="$tag" + fi + else + # New base version is higher + LATEST_TRANSPORT_TAG="$tag" + LATEST_BASE_VERSION="$base_version" + fi + fi + done + if [ -z "$LATEST_TRANSPORT_TAG" ]; then + echo " βœ… First transport release: $TRANSPORT_VERSION" + if [ "$GIT_TAG_EXISTS" = "false" ]; then + echo " 🏷️ Git tag missing - transport release needed" + BIFROST_HTTP_NEEDS_RELEASE="true" + fi + else + PREVIOUS_TRANSPORT_VERSION=${LATEST_TRANSPORT_TAG#transports/v} + echo " πŸ“‹ Previous: $PREVIOUS_TRANSPORT_VERSION, Current: $TRANSPORT_VERSION" + # Debug the sort behavior + sorted_first=$(printf '%s\n' "$PREVIOUS_TRANSPORT_VERSION" "$TRANSPORT_VERSION" | sort -V | head -1) + echo " πŸ” DEBUG: sort -V | head -1 returns: '$sorted_first'" + echo " πŸ” DEBUG: Current version: '$TRANSPORT_VERSION'" + echo " πŸ” DEBUG: Versions different? $([ "$PREVIOUS_TRANSPORT_VERSION" != "$TRANSPORT_VERSION" ] && echo "YES" || echo "NO")" + # Fixed: Check if previous version sorts first (meaning current is greater) + if [ "$sorted_first" = "$PREVIOUS_TRANSPORT_VERSION" ] && [ "$PREVIOUS_TRANSPORT_VERSION" != "$TRANSPORT_VERSION" ]; then + echo " βœ… Transport version incremented: $PREVIOUS_TRANSPORT_VERSION β†’ $TRANSPORT_VERSION" + if [ "$GIT_TAG_EXISTS" = "false" ]; then + echo " 🏷️ Git tag missing - transport release needed" + BIFROST_HTTP_NEEDS_RELEASE="true" + fi + else + echo " ⏭️ No transport version increment" + fi + fi +fi + +# Check if Docker image needs to be built (independent of transport release) +if [ "$DOCKER_TAG_EXISTS" = "false" ]; then + echo " 🐳 Docker image missing - docker release needed" + DOCKER_NEEDS_RELEASE="true" +fi + + +# Convert plugin array to JSON (compact format) +if [ ${#PLUGIN_CHANGES[@]} -eq 0 ]; then + CHANGED_PLUGINS_JSON="[]" +else + CHANGED_PLUGINS_JSON=$(printf '%s\n' "${PLUGIN_CHANGES[@]}" | jq -R . | jq -s -c .) +fi + +echo "CHANGED_PLUGINS_JSON: $CHANGED_PLUGINS_JSON" + +# Summary +echo "" +echo "πŸ“‹ Release Summary:" +echo " Core: $CORE_NEEDS_RELEASE (v$CORE_VERSION)" +echo " Framework: $FRAMEWORK_NEEDS_RELEASE (v$FRAMEWORK_VERSION)" +echo " Plugins: $PLUGINS_NEED_RELEASE (${#PLUGIN_CHANGES[@]} plugins)" +echo " Bifrost HTTP: $BIFROST_HTTP_NEEDS_RELEASE (v$TRANSPORT_VERSION)" +echo " Docker: $DOCKER_NEEDS_RELEASE (v$TRANSPORT_VERSION)" + +# Set outputs (only when running in GitHub Actions) +if [ -n "${GITHUB_OUTPUT:-}" ]; then + { + echo "core-needs-release=$CORE_NEEDS_RELEASE" + echo "framework-needs-release=$FRAMEWORK_NEEDS_RELEASE" + echo "plugins-need-release=$PLUGINS_NEED_RELEASE" + echo "bifrost-http-needs-release=$BIFROST_HTTP_NEEDS_RELEASE" + echo "docker-needs-release=$DOCKER_NEEDS_RELEASE" + echo "changed-plugins=$CHANGED_PLUGINS_JSON" + echo "core-version=$CORE_VERSION" + echo "framework-version=$FRAMEWORK_VERSION" + echo "transport-version=$TRANSPORT_VERSION" + } >> "$GITHUB_OUTPUT" +else + echo "ℹ️ GITHUB_OUTPUT not set; skipping outputs write (local run)" +fi \ No newline at end of file diff --git a/.github/workflows/scripts/extract-npx-version.sh b/.github/workflows/scripts/extract-npx-version.sh new file mode 100755 index 000000000..c6c89b516 --- /dev/null +++ b/.github/workflows/scripts/extract-npx-version.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Extract NPX version from tag +# Usage: ./extract-npx-version.sh + +# Extract tag name from ref (prefer GITHUB_REF_NAME, fallback to GITHUB_REF) +# Use an intermediate to avoid set -u errors when both are unset in local runs +RAW_REF="${GITHUB_REF_NAME:-${GITHUB_REF:-}}" +TAG_NAME="${RAW_REF#refs/tags/}" +if [[ -z "${TAG_NAME}" ]]; then + echo "❌ TAG_NAME is empty. Ensure this runs on a tag ref or set GITHUB_REF_NAME." + exit 1 +fi + +echo "πŸ“‹ Processing tag: ${TAG_NAME}" + +# Validate tag format (npx/vX.Y.Z or prerelease like npx/vX.Y.Z-rc.1) +if [[ ! "${TAG_NAME}" =~ ^npx/v[0-9]+\.[0-9]+\.[0-9]+(-[0-9A-Za-z.-]+)?(\+[0-9A-Za-z.-]+)?$ ]]; then + echo "❌ Invalid tag format '${TAG_NAME}'. Expected format: npx/vMAJOR.MINOR.PATCH" + exit 1 +fi + +# Extract version (remove 'npx/v' prefix to get just the version number) +VERSION="${TAG_NAME#npx/v}" +echo "πŸ“¦ Extracted NPX version: ${VERSION}" +echo "🏷️ Full tag: ${TAG_NAME}" +# Set outputs (only when running in GitHub Actions) +if [[ -n "${GITHUB_OUTPUT:-}" ]]; then + { + echo "version=${VERSION}" + echo "full-tag=${TAG_NAME}" + } >> "$GITHUB_OUTPUT" +else + echo "::notice::GITHUB_OUTPUT not set; skipping outputs (local run?)" +fi \ No newline at end of file diff --git a/.github/workflows/scripts/go-utils.sh b/.github/workflows/scripts/go-utils.sh new file mode 100755 index 000000000..1a2290385 --- /dev/null +++ b/.github/workflows/scripts/go-utils.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# Shared utilities for Go operations in release scripts +# Usage: source .github/workflows/scripts/go-utils.sh + +# Function to perform go get with exponential backoff +# Usage: go_get_with_backoff +go_get_with_backoff() { + local package="$1" + local max_attempts=30 + local initial_wait=30 + local max_wait=120 # 2 minutes + local attempt=1 + local wait_time=$initial_wait + + echo "πŸ”„ Attempting to get $package with exponential backoff..." + + while [ $attempt -le $max_attempts ]; do + echo "πŸ“¦ Attempt $attempt/$max_attempts: go get $package" + + if go get "$package"; then + echo "βœ… Successfully retrieved $package on attempt $attempt" + return 0 + fi + + if [ $attempt -eq $max_attempts ]; then + echo "❌ Failed to get $package after $max_attempts attempts" + return 1 + fi + + echo "⏳ Waiting ${wait_time}s before retry (attempt $attempt/$max_attempts failed)..." + sleep $wait_time + + # Calculate next wait time (exponential backoff) + # Double the wait time, but cap at max_wait + wait_time=$((wait_time * 2)) + if [ $wait_time -gt $max_wait ]; then + wait_time=$max_wait + fi + + attempt=$((attempt + 1)) + done + + return 1 +} diff --git a/.github/workflows/scripts/install-cross-compilers.sh b/.github/workflows/scripts/install-cross-compilers.sh new file mode 100755 index 000000000..65051171b --- /dev/null +++ b/.github/workflows/scripts/install-cross-compilers.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Install cross-compilation toolchains for Go + CGO +# Usage: ./install-cross-compilers.sh + +echo "πŸ“¦ Installing cross-compilation toolchains for Go + CGO..." + +# Install all required packages +sudo apt-get update +sudo apt-get install -y \ + gcc-x86-64-linux-gnu \ + gcc-aarch64-linux-gnu \ + gcc-mingw-w64-x86-64 \ + musl-tools \ + clang \ + lld \ + xz-utils \ + curl + +# Create symbolic links for musl compilers +sudo ln -sf /usr/bin/x86_64-linux-gnu-gcc /usr/local/bin/x86_64-linux-musl-gcc +sudo ln -sf /usr/bin/x86_64-linux-gnu-g++ /usr/local/bin/x86_64-linux-musl-g++ +sudo ln -sf /usr/bin/aarch64-linux-gnu-gcc /usr/local/bin/aarch64-linux-musl-gcc +sudo ln -sf /usr/bin/aarch64-linux-gnu-g++ /usr/local/bin/aarch64-linux-musl-g++ + +echo "🍎 Setting up Darwin cross-compilation..." + +# Where to install SDK +SDK_DIR="/opt/MacOSX11.3.sdk" +SDK_URL="https://github.com/phracker/MacOSX-SDKs/releases/download/11.3/MacOSX11.3.sdk.tar.xz" + +# Download and extract macOS SDK if not already installed +if [ ! -d "$SDK_DIR" ]; then + echo "πŸ“¦ Downloading macOS SDK..." + curl -L "$SDK_URL" -o /tmp/MacOSX11.3.sdk.tar.xz + sudo mkdir -p /opt + sudo tar -xf /tmp/MacOSX11.3.sdk.tar.xz -C /opt + rm -f /tmp/MacOSX11.3.sdk.tar.xz +fi + +# Create wrapper scripts with proper shebang and linker configuration +sudo tee /usr/local/bin/o64-clang > /dev/null << 'WRAPPER_EOF' +#!/bin/bash +exec clang -target x86_64-apple-darwin --sysroot=/opt/MacOSX11.3.sdk -fuse-ld=lld -Wno-unused-command-line-argument "$@" +WRAPPER_EOF + +sudo tee /usr/local/bin/o64-clang++ > /dev/null << 'WRAPPER_EOF' +#!/bin/bash +exec clang++ -target x86_64-apple-darwin --sysroot=/opt/MacOSX11.3.sdk -fuse-ld=lld -Wno-unused-command-line-argument "$@" +WRAPPER_EOF + +sudo tee /usr/local/bin/oa64-clang > /dev/null << 'WRAPPER_EOF' +#!/bin/bash +exec clang -target arm64-apple-darwin --sysroot=/opt/MacOSX11.3.sdk -fuse-ld=lld -Wno-unused-command-line-argument "$@" +WRAPPER_EOF + +sudo tee /usr/local/bin/oa64-clang++ > /dev/null << 'WRAPPER_EOF' +#!/bin/bash +exec clang++ -target arm64-apple-darwin --sysroot=/opt/MacOSX11.3.sdk -fuse-ld=lld -Wno-unused-command-line-argument "$@" +WRAPPER_EOF + +sudo chmod +x /usr/local/bin/o64-clang /usr/local/bin/o64-clang++ \ + /usr/local/bin/oa64-clang /usr/local/bin/oa64-clang++ + +echo "βœ… Darwin cross-compilation environment ready!" + +echo "βœ… Cross-compilation toolchains installed" +echo "" +echo "Available cross-compilers:" +echo " Linux amd64: x86_64-linux-musl-gcc, x86_64-linux-musl-g++" +echo " Linux arm64: aarch64-linux-musl-gcc, aarch64-linux-musl-g++" +echo " Windows amd64: x86_64-w64-mingw32-gcc, x86_64-w64-mingw32-g++" +echo " Windows arm64: aarch64-w64-mingw32-gcc, aarch64-w64-mingw32-g++" +echo " Darwin amd64: o64-clang, o64-clang++" +echo " Darwin arm64: oa64-clang, oa64-clang++" \ No newline at end of file diff --git a/.github/workflows/scripts/push-mintlify-changelog.sh b/.github/workflows/scripts/push-mintlify-changelog.sh new file mode 100755 index 000000000..14cc78b3c --- /dev/null +++ b/.github/workflows/scripts/push-mintlify-changelog.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash + +VERSION=$1 + +if [ -z "$VERSION" ]; then + echo "Usage: $0 " + echo "Example: $0 1.2.0" + exit 1 +fi + +VERSION="v$VERSION" + +# Check if this page already exists in docs/changelogs/ +if [ -f "docs/changelogs/$VERSION.mdx" ]; then + echo "βœ… Changelog for $VERSION already exists" + exit 0 +fi + +# Source changelog utilities +source "$(dirname "$0")/changelog-utils.sh" + +# Preparing changelog file +CHANGELOG_BODY="--- +title: \"$VERSION\" +description: \"$VERSION changelog\" +---" + +# Helper to append a section if changelog file exists and is non-empty +append_section () { + label=$1 + path=$2 + if [ -f "$path" ]; then + content=$(get_changelog_content "$path") || return 0 + CHANGELOG_BODY+=$'\n'""$'\n'"$content"$'\n\n'"" + fi +} + +# HTTP changelog +append_section "Bifrost(HTTP)" transports/changelog.md + +# Core changelog +append_section "Core" core/changelog.md + +# Framework changelog +append_section "Framework" framework/changelog.md + +# Plugins changelogs +for plugin in plugins/*; do + name=$(basename "$plugin") + append_section "$name" "$plugin/changelog.md" +done + +# Write to file +mkdir -p docs/changelogs +echo "$CHANGELOG_BODY" > docs/changelogs/$VERSION.mdx + +# Update docs.json to include this new changelog route in the Changelogs tab pages array +# Handles both non-empty and empty array forms +route="changelogs/$VERSION" +if ! grep -q "\"$route\"" docs/docs.json; then + awk -v route="$route" ' + function indent(line){ + x = line + sub(/[^[:space:]].*$/, "", x) + return x + } + $0 ~ /"tab": "Changelogs"/ { in_tab=1 } + in_tab && $0 ~ /"pages": \[\]/ { + ind = indent($0) + print ind "\"pages\": [" + print ind " \"" route "\"" + print ind "]" + fixing_empty=1 + in_tab=0 + next + } + in_tab && $0 ~ /"pages": \[/ { + print + ind = indent($0) + print ind " \"" route "\"," + in_tab=0 + next + } + fixing_empty && $0 ~ /^[[:space:]]*"changelogs\/[^"]+",?$/ { + fixing_empty=0 + next + } + { print } + ' docs/docs.json > docs/docs.json.tmp && mv docs/docs.json.tmp docs/docs.json +fi + + + +# Commit and push changes +git add docs/changelogs/$VERSION.mdx +git add docs/docs.json +git config user.name "github-actions[bot]" +git config user.email "41898282+github-actions[bot]@users.noreply.github.com" +git commit -m "Adds changelog for $VERSION --skip-pipeline" +git push origin main diff --git a/.github/workflows/scripts/release-all-plugins.sh b/.github/workflows/scripts/release-all-plugins.sh new file mode 100755 index 000000000..16a21d0a1 --- /dev/null +++ b/.github/workflows/scripts/release-all-plugins.sh @@ -0,0 +1,136 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Release all changed plugins sequentially +# Usage: ./release-all-plugins.sh '["plugin1", "plugin2"]' + +# Validate that an argument was provided +if [ $# -eq 0 ]; then + echo "❌ Error: Missing required argument" + echo "Usage: $0 ''" + echo "Example: $0 '[\"plugin1\", \"plugin2\"]'" + exit 1 +fi + +CHANGED_PLUGINS_JSON="$1" + +# Verify jq is available +if ! command -v jq >/dev/null 2>&1; then + echo "❌ Error: jq is required but not installed" + echo "Please install jq to parse JSON input" + exit 1 +fi + +# Validate that the input is valid JSON +if ! echo "$CHANGED_PLUGINS_JSON" | jq empty >/dev/null 2>&1; then + echo "❌ Error: Invalid JSON provided" + echo "Input: $CHANGED_PLUGINS_JSON" + echo "Please provide a valid JSON array of plugin names" + exit 1 +fi + + +# Starting dependencies of plugin tests +echo "πŸ”§ Starting dependencies of plugin tests..." +# Use docker compose (v2) if available, fallback to docker-compose (v1) +if command -v docker-compose >/dev/null 2>&1; then + docker-compose -f tests/docker-compose.yml up -d +elif docker compose version >/dev/null 2>&1; then + docker compose -f tests/docker-compose.yml up -d +else + echo "❌ Neither docker-compose nor docker compose is available" + exit 1 +fi +sleep 20 + +echo "πŸ”Œ Processing plugin releases..." +echo "πŸ“‹ Changed plugins JSON: $CHANGED_PLUGINS_JSON" + +# No work early‐exit if array is empty +if jq -e 'length==0' <<<"$CHANGED_PLUGINS_JSON" >/dev/null 2>&1; then + echo "⏭️ No plugins to release" + echo "success=true" >> "${GITHUB_OUTPUT:-/dev/null}" + exit 0 +fi + +# Convert JSON array to bash array using readarray to avoid word-splitting +if ! readarray -t PLUGINS < <(echo "$CHANGED_PLUGINS_JSON" | jq -r '.[]' 2>/dev/null); then + echo "❌ Error: Failed to parse plugin names from JSON" + echo "Input: $CHANGED_PLUGINS_JSON" + exit 1 +fi + +# Verify release-single-plugin.sh exists and is executable +RELEASE_SCRIPT="./.github/workflows/scripts/release-single-plugin.sh" +if [ ! -f "$RELEASE_SCRIPT" ]; then + echo "❌ Error: Release script not found: $RELEASE_SCRIPT" + exit 1 +fi + +if [ ! -x "$RELEASE_SCRIPT" ]; then + echo "❌ Error: Release script is not executable: $RELEASE_SCRIPT" + exit 1 +fi + +if [ ${#PLUGINS[@]} -eq 0 ]; then + echo "⏭️ No plugins to release" + echo "success=true" >> "${GITHUB_OUTPUT:-/dev/null}" + exit 0 +fi + +echo "πŸ”„ Releasing ${#PLUGINS[@]} plugins:" +for p in "${PLUGINS[@]}"; do + echo " β€’ $p" +done + +FAILED_PLUGINS=() +SUCCESS_COUNT=0 +OVERALL_EXIT_CODE=0 + +# Release each plugin +for plugin in "${PLUGINS[@]}"; do + echo "" + echo "πŸ”Œ Releasing plugin: $plugin" + + # Capture the exit code of the plugin release + if "$RELEASE_SCRIPT" "$plugin"; then + PLUGIN_EXIT_CODE=$? + echo "βœ… Successfully released: $plugin" + SUCCESS_COUNT=$((SUCCESS_COUNT + 1)) + else + PLUGIN_EXIT_CODE=$? + echo "❌ Failed to release plugin '$plugin' (exit code: $PLUGIN_EXIT_CODE)" + FAILED_PLUGINS+=("$plugin") + OVERALL_EXIT_CODE=1 + fi +done + + +# Shutting down dependencies +echo "πŸ”§ Shutting down dependencies of plugin tests..." +# Use docker compose (v2) if available, fallback to docker-compose (v1) +if command -v docker-compose >/dev/null 2>&1; then + docker-compose -f tests/docker-compose.yml down +elif docker compose version >/dev/null 2>&1; then + docker compose -f tests/docker-compose.yml down +else + echo "❌ Neither docker-compose nor docker compose is available" + exit 1 +fi + +# Summary +echo "" +echo "πŸ“‹ Plugin Release Summary:" +echo " βœ… Successful: $SUCCESS_COUNT/${#PLUGINS[@]}" +echo " ❌ Failed: ${#FAILED_PLUGINS[@]}" + +if [ ${#FAILED_PLUGINS[@]} -gt 0 ]; then + echo " Failed plugins: ${FAILED_PLUGINS[*]}" + echo "success=false" >> "${GITHUB_OUTPUT:-/dev/null}" + echo "❌ Plugin release process completed with failures" + exit $OVERALL_EXIT_CODE +else + echo " πŸŽ‰ All plugins released successfully!" + echo "success=true" >> "${GITHUB_OUTPUT:-/dev/null}" + echo "βœ… All plugin releases completed successfully" +fi diff --git a/.github/workflows/scripts/release-bifrost-http.sh b/.github/workflows/scripts/release-bifrost-http.sh new file mode 100755 index 000000000..8f42271ad --- /dev/null +++ b/.github/workflows/scripts/release-bifrost-http.sh @@ -0,0 +1,306 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Release bifrost-http component +# Usage: ./release-bifrost-http.sh + +# Source Go utilities for exponential backoff +source "$(dirname "$0")/go-utils.sh" + +# Validate input argument +if [ "${1:-}" = "" ]; then + echo "Usage: $0 " >&2 + exit 1 +fi + +VERSION="$1" +TAG_NAME="transports/v${VERSION}" + +echo "πŸš€ Releasing bifrost-http v$VERSION..." + +# Ensure tags are available (CI often does shallow clones) +git fetch --tags --force >/dev/null 2>&1 || true +LATEST_CORE_TAG=$(git tag -l "core/v*" | sort -V | tail -1) +LATEST_FRAMEWORK_TAG=$(git tag -l "framework/v*" | sort -V | tail -1) + +if [ -z "$LATEST_CORE_TAG" ]; then + CORE_VERSION="v$(tr -d '\n\r' < core/version)" +else + CORE_VERSION=${LATEST_CORE_TAG#core/} +fi + +if [ -z "$LATEST_FRAMEWORK_TAG" ]; then + FRAMEWORK_VERSION="v$(tr -d '\n\r' < framework/version)" +else + FRAMEWORK_VERSION=${LATEST_FRAMEWORK_TAG#framework/} +fi + +echo "πŸ” DEBUG: LATEST_CORE_TAG: $LATEST_CORE_TAG" +echo "πŸ” DEBUG: CORE_VERSION: $CORE_VERSION" +echo "πŸ” DEBUG: LATEST_FRAMEWORK_TAG: $LATEST_FRAMEWORK_TAG" +echo "πŸ” DEBUG: FRAMEWORK_VERSION: $FRAMEWORK_VERSION" + + +# Get latest plugin versions +echo "πŸ”Œ Getting latest plugin release versions..." +declare -A PLUGIN_VERSIONS + +# First, get versions for plugins that exist in the plugins/ directory +for plugin_dir in plugins/*/; do + if [ -d "$plugin_dir" ]; then + plugin_name=$(basename "$plugin_dir") + + # Check if VERSION parameter contains prerelease suffix + if [[ "$VERSION" == *"-"* ]]; then + # VERSION has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-') + + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" + else + # No stable versions, get highest prerelease + LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" + fi + else + # VERSION has no prerelease, so only consider stable releases + LATEST_PLUGIN_TAG=$(git tag -l "plugins/${plugin_name}/v*" | grep -v '\-' | sort -V | tail -1) + echo "latest plugin tag (stable only): $LATEST_PLUGIN_TAG" + fi + + if [ -z "$LATEST_PLUGIN_TAG" ]; then + # No matching release found, use version from file + PLUGIN_VERSION="v$(tr -d '\n\r' < "${plugin_dir}version")" + echo " πŸ“¦ $plugin_name: $PLUGIN_VERSION (from version file - not yet released)" + else + PLUGIN_VERSION=${LATEST_PLUGIN_TAG#plugins/${plugin_name}/} + echo " πŸ“¦ $plugin_name: $PLUGIN_VERSION (latest release)" + fi + + PLUGIN_VERSIONS["$plugin_name"]="$PLUGIN_VERSION" + fi +done + +# Also check for any plugins already in transport go.mod that might not be in plugins/ directory +cd transports +echo "πŸ” Checking for additional plugins in transport go.mod..." +# Parse go.mod plugin lines and add missing ones +while IFS= read -r plugin_line; do + plugin_name=$(echo "$plugin_line" | awk -F'/' '{print $NF}' | awk '{print $1}') + current_version=$(echo "$plugin_line" | awk '{print $NF}') + + # Only add if we don't already have this plugin + if [[ -z "${PLUGIN_VERSIONS[$plugin_name]:-}" ]]; then + echo " πŸ“¦ $plugin_name: $current_version (from transport go.mod)" + PLUGIN_VERSIONS["$plugin_name"]="$current_version" + fi +done < <(grep "github.com/maximhq/bifrost/plugins/" go.mod) +cd .. + +echo "πŸ”§ Using versions:" +echo " Core: $CORE_VERSION" +echo " Framework: $FRAMEWORK_VERSION" +echo " Plugins:" +for plugin_name in "${!PLUGIN_VERSIONS[@]}"; do + echo " - $plugin_name: ${PLUGIN_VERSIONS[$plugin_name]}" +done + +# Update transport dependencies to use latest plugin releases +echo "πŸ”§ Using latest plugin release versions for transport..." +PLUGINS_USED=() + +# Track which plugins are actually used by the transport +cd transports +for plugin_name in "${!PLUGIN_VERSIONS[@]}"; do + plugin_version="${PLUGIN_VERSIONS[$plugin_name]}" + + # Check if transport depends on this plugin + if grep -q "github.com/maximhq/bifrost/plugins/$plugin_name" go.mod; then + echo " πŸ“¦ Using $plugin_name plugin $plugin_version" + go_get_with_backoff "github.com/maximhq/bifrost/plugins/$plugin_name@$plugin_version" + PLUGINS_USED+=("$plugin_name:$plugin_version") + fi +done + +# Also ensure core and framework are up to date + +echo " πŸ”§ Updating core to $CORE_VERSION" +go_get_with_backoff "github.com/maximhq/bifrost/core@$CORE_VERSION" + +echo " πŸ“¦ Updating framework to $FRAMEWORK_VERSION" +go_get_with_backoff "github.com/maximhq/bifrost/framework@$FRAMEWORK_VERSION" + +go mod tidy + +cd .. + +# We need to build UI first before we can validate the transport build +echo "🎨 Building UI..." +make build-ui + +# Validate transport build +echo "πŸ”¨ Validating transport build..." +cd transports +go test ./... +cd .. +echo "βœ… Transport build validation successful" + +# Commit and push changes if any +# First, stage any changes made to transports/ +git add transports/ +if ! git diff --cached --quiet; then + git pull origin main + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + echo "πŸ”§ Committing and pushing changes..." + git commit -m "transports: update dependencies --skip-pipeline" + git push -u origin HEAD +else + echo "ℹ️ No staged changes to commit" +fi + +# Install cross-compilation toolchains +echo "πŸ“¦ Installing cross-compilation toolchains..." +bash ./.github/workflows/scripts/install-cross-compilers.sh + +# Build Go executables +echo "πŸ”¨ Building executables..." +bash ./.github/workflows/scripts/build-executables.sh $VERSION + +# Configure and upload to R2 +echo "πŸ“€ Uploading binaries..." +bash ./.github/workflows/scripts/configure-r2.sh +bash ./.github/workflows/scripts/upload-to-r2.sh "$TAG_NAME" + +# Capturing changelog +CHANGELOG_BODY=$(cat transports/changelog.md) +# Skip comments from changelog +CHANGELOG_BODY=$(echo "$CHANGELOG_BODY" | grep -v '^') +# If changelog is empty, return error +if [ -z "$CHANGELOG_BODY" ]; then + echo "❌ Changelog is empty" + exit 1 +fi +echo "πŸ“ New changelog: $CHANGELOG_BODY" + +# Finding previous tag +echo "πŸ” Finding previous tag..." +PREV_TAG=$(git tag -l "transports/v*" | sort -V | tail -1) +if [[ "$PREV_TAG" == "$TAG_NAME" ]]; then + PREV_TAG=$(git tag -l "transports/v*" | sort -V | tail -2 | head -1) +fi +echo "πŸ” Previous tag: $PREV_TAG" + +# Get message of the tag +echo "πŸ” Getting previous tag message..." +PREV_CHANGELOG=$(git tag -l --format='%(contents)' "$PREV_TAG") +echo "πŸ“ Previous changelog body: $PREV_CHANGELOG" + +# Checking if tag message is the same as the changelog +if [[ "$PREV_CHANGELOG" == "$CHANGELOG_BODY" ]]; then + echo "❌ Changelog is the same as the previous changelog" + exit 1 +fi + +# Create and push tag +echo "🏷️ Creating tag: $TAG_NAME" +git tag "$TAG_NAME" -m "Release transports v$VERSION" -m "$CHANGELOG_BODY" +git push origin "$TAG_NAME" + +# Create GitHub release +TITLE="Bifrost HTTP v$VERSION" + +# Mark prereleases when version contains a hyphen +PRERELEASE_FLAG="" +if [[ "$VERSION" == *-* ]]; then + PRERELEASE_FLAG="--prerelease" +fi + +LATEST_FLAG="" +if [[ "$VERSION" != *-* ]]; then + LATEST_FLAG="--latest" +fi + +# Generate plugin version summary +PLUGIN_UPDATES="" +if [ ${#PLUGINS_USED[@]} -gt 0 ]; then + PLUGIN_UPDATES=" + +### πŸ”Œ Plugin Versions +This release includes the following plugin versions: +" + for plugin_info in "${PLUGINS_USED[@]}"; do + plugin_name="${plugin_info%%:*}" + plugin_version="${plugin_info##*:}" + PLUGIN_UPDATES="$PLUGIN_UPDATES- **$plugin_name**: \`$plugin_version\` +" + done +else + # Show all available plugin versions even if not directly used + PLUGIN_UPDATES=" + +### πŸ”Œ Available Plugin Versions +The following plugin versions are compatible with this release: +" + for plugin_name in "${!PLUGIN_VERSIONS[@]}"; do + plugin_version="${PLUGIN_VERSIONS[$plugin_name]}" + PLUGIN_UPDATES="$PLUGIN_UPDATES- **$plugin_name**: \`$plugin_version\` +" + done +fi + +BODY="## Bifrost HTTP Transport Release v$VERSION + +$CHANGELOG_BODY + +### Installation + +#### Docker +\`\`\`bash +docker run -p 8080:8080 maximhq/bifrost:v$VERSION +\`\`\` + +#### Binary Download +\`\`\`bash +npx @maximhq/bifrost --transport-version v$VERSION +\`\`\` + +### Docker Images +- **\`maximhq/bifrost:v$VERSION\`** - This specific version +- **\`maximhq/bifrost:latest\`** - Latest version (updated with this release) + +--- +_This release was automatically created with dependencies: core \`$CORE_VERSION\`, framework \`$FRAMEWORK_VERSION\`. All plugins have been validated and updated._" + +if [ -z "${GH_TOKEN:-}" ] && [ -z "${GITHUB_TOKEN:-}" ]; then + echo "Error: GH_TOKEN or GITHUB_TOKEN is not set. Please export one to authenticate the GitHub CLI." + exit 1 +fi + +echo "πŸŽ‰ Creating GitHub release for $TITLE..." +gh release create "$TAG_NAME" \ + --title "$TITLE" \ + --notes "$BODY" \ + ${PRERELEASE_FLAG} ${LATEST_FLAG} + +echo "βœ… Bifrost HTTP released successfully" + +# Print summary +echo "" +echo "πŸ“‹ Release Summary:" +echo " 🏷️ Tag: $TAG_NAME" +echo " πŸ”§ Core version: $CORE_VERSION" +echo " πŸ”§ Framework version: $FRAMEWORK_VERSION" +echo " πŸ“¦ Transport: Updated" +if [ ${#PLUGINS_USED[@]} -gt 0 ]; then + echo " πŸ”Œ Plugins used: ${PLUGINS_USED[*]}" +else + echo " πŸ”Œ Available plugins: $(printf "%s " "${!PLUGIN_VERSIONS[@]}")" +fi +echo " πŸŽ‰ GitHub release: Created" + +echo "success=true" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/scripts/release-core.sh b/.github/workflows/scripts/release-core.sh new file mode 100755 index 000000000..ed612fbef --- /dev/null +++ b/.github/workflows/scripts/release-core.sh @@ -0,0 +1,109 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Release core component +# Usage: ./release-core.sh + +if [[ "${1:-}" == "" ]]; then + echo "Usage: $0 " + echo "Example: $0 1.2.0" + exit 1 +fi +VERSION="$1" + +TAG_NAME="core/v${VERSION}" + +echo "πŸ”§ Releasing core v$VERSION..." + +# Validate core build +echo "πŸ”¨ Validating core build..." +cd core + +if [[ ! -f version ]]; then + echo "❌ Missing core/version file" + exit 1 +fi +FILE_VERSION="$(cat version | tr -d '[:space:]')" +if [[ "$FILE_VERSION" != "$VERSION" ]]; then + echo "❌ Version mismatch: arg=$VERSION, core/version=$FILE_VERSION" + exit 1 +fi + +# Building core +go mod download +go build ./... +go test ./... +cd .. +echo "βœ… Core build validation successful" + + +# Capturing changelog +CHANGELOG_BODY=$(cat core/changelog.md) +# Skip comments from changelog +CHANGELOG_BODY=$(echo "$CHANGELOG_BODY" | grep -v '^') +# If changelog is empty, return error +if [ -z "$CHANGELOG_BODY" ]; then + echo "❌ Changelog is empty" + exit 1 +fi +echo "πŸ“ New changelog: $CHANGELOG_BODY" + +# Finding previous tag +echo "πŸ” Finding previous tag..." +PREV_TAG=$(git tag -l "core/v*" | sort -V | tail -1) +if [[ "$PREV_TAG" == "$TAG_NAME" ]]; then + PREV_TAG=$(git tag -l "core/v*" | sort -V | tail -2 | head -1) +fi +echo "πŸ” Previous tag: $PREV_TAG" + +# Get message of the tag +echo "πŸ” Getting previous tag message..." +PREV_CHANGELOG=$(git tag -l --format='%(contents)' "$PREV_TAG") +echo "πŸ“ Previous changelog body: $PREV_CHANGELOG" + +# Checking if tag message is the same as the changelog +if [[ "$PREV_CHANGELOG" == "$CHANGELOG_BODY" ]]; then + echo "❌ Changelog is the same as the previous changelog" + exit 1 +fi + +# Create and push tag +echo "🏷️ Creating tag: $TAG_NAME" +git tag "$TAG_NAME" -m "Release core v$VERSION" -m "$CHANGELOG_BODY" +git push origin "$TAG_NAME" + +# Create GitHub release +TITLE="Core v$VERSION" + +# Mark prereleases when version contains a hyphen +PRERELEASE_FLAG="" +if [[ "$VERSION" == *-* ]]; then + PRERELEASE_FLAG="--prerelease" +fi + +LATEST_FLAG="" +if [[ "$VERSION" != *-* ]]; then + LATEST_FLAG="--latest" +fi + +BODY="## Core Release v$VERSION + +$CHANGELOG_BODY + +### Installation + +\`\`\`bash +go get github.com/maximhq/bifrost/core@v$VERSION +\`\`\` + +--- +_This release was automatically created from version file: \`core/version\`_" + +echo "πŸŽ‰ Creating GitHub release for $TITLE..." +gh release create "$TAG_NAME" \ + --title "$TITLE" \ + --notes "$BODY" \ + ${PRERELEASE_FLAG} ${LATEST_FLAG} + +echo "βœ… Core released successfully" +echo "success=true" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/scripts/release-framework.sh b/.github/workflows/scripts/release-framework.sh new file mode 100755 index 000000000..8f923bf3d --- /dev/null +++ b/.github/workflows/scripts/release-framework.sh @@ -0,0 +1,176 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Release framework component +# Usage: ./release-framework.sh + +# Source Go utilities for exponential backoff +source "$(dirname "$0")/go-utils.sh" + +# Making sure version is provided +if [ $# -ne 1 ]; then + echo "Usage: $0 " >&2 + exit 1 +fi + +VERSION_RAW="$1" +# Ensure leading 'v' for module/tag semver +if [[ "$VERSION_RAW" == v* ]]; then + VERSION="$VERSION_RAW" +else + VERSION="v$VERSION_RAW" +fi + +TAG_NAME="framework/${VERSION}" + +echo "πŸ“¦ Releasing framework $VERSION..." + +# Ensure we have the latest version +git pull origin +# Fetching all tags +git fetch --tags >/dev/null 2>&1 || true + +# Get latest core version +LATEST_CORE_TAG=$(git tag -l "core/v*" | sort -V | tail -1) +if [ -z "$LATEST_CORE_TAG" ]; then + CORE_VERSION="v$(tr -d '\n\r' < core/version)" +else + CORE_VERSION=${LATEST_CORE_TAG#core/} +fi + +echo "πŸ”§ Using core version: $CORE_VERSION" + +# Update framework dependencies +echo "πŸ”§ Updating framework dependencies..." +cd framework +go_get_with_backoff "github.com/maximhq/bifrost/core@$CORE_VERSION" +go mod tidy +git add go.mod go.sum + +# Check if there are any changes to commit +git add go.mod go.sum + + +# Validate framework build +echo "πŸ”¨ Validating framework build..." +go build ./... +# Starting dependencies of framework tests +echo "πŸ”§ Starting dependencies of framework tests..." +# Use docker compose (v2) if available, fallback to docker-compose (v1) +if command -v docker-compose >/dev/null 2>&1; then + docker-compose -f ../tests/docker-compose.yml up -d +elif docker compose version >/dev/null 2>&1; then + docker compose -f ../tests/docker-compose.yml up -d +else + echo "❌ Neither docker-compose nor docker compose is available" + exit 1 +fi +sleep 20 +go test ./... +# Shutting down dependencies +echo "πŸ”§ Shutting down dependencies of framework tests..." +# Use docker compose (v2) if available, fallback to docker-compose (v1) +if command -v docker-compose >/dev/null 2>&1; then + docker-compose -f ../tests/docker-compose.yml down +elif docker compose version >/dev/null 2>&1; then + docker compose -f ../tests/docker-compose.yml down +else + echo "❌ Neither docker-compose nor docker compose is available" + exit 1 +fi +cd .. + +echo "βœ… Framework build validation successful" + +# Check if there are any changes to commit +if ! git diff --cached --quiet; then + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + git commit -m "framework: bump core to $CORE_VERSION --skip-pipeline" + # Push the bump so go.mod/go.sum changes are recorded on the branch + CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" + git push origin "$CURRENT_BRANCH" + echo "πŸ”§ Pushed framework bump to $CURRENT_BRANCH" +else + echo "No dependency changes detected; skipping commit." +fi + +# Capturing changelog +CHANGELOG_BODY=$(cat framework/changelog.md) +# Skip comments from changelog +CHANGELOG_BODY=$(echo "$CHANGELOG_BODY" | grep -v '^') +# If changelog is empty, return error +if [ -z "$CHANGELOG_BODY" ]; then + echo "❌ Changelog is empty" + exit 1 +fi +echo "πŸ“ New changelog: $CHANGELOG_BODY" + +# Finding previous tag +echo "πŸ” Finding previous tag..." +PREV_TAG=$(git tag -l "framework/v*" | sort -V | tail -1) +if [[ "$PREV_TAG" == "$TAG_NAME" ]]; then + PREV_TAG=$(git tag -l "framework/v*" | sort -V | tail -2 | head -1) +fi +echo "πŸ” Previous tag: $PREV_TAG" + +# Get message of the tag +echo "πŸ” Getting previous tag message..." +PREV_CHANGELOG=$(git tag -l --format='%(contents)' "$PREV_TAG") +echo "πŸ“ Previous changelog body: $PREV_CHANGELOG" + +# Checking if tag message is the same as the changelog +if [[ "$PREV_CHANGELOG" == "$CHANGELOG_BODY" ]]; then + echo "❌ Changelog is the same as the previous changelog" + exit 1 +fi + +# Create and push tag +echo "🏷️ Creating tag: $TAG_NAME" +if git rev-parse --verify "$TAG_NAME" >/dev/null 2>&1; then + echo "Tag $TAG_NAME already exists; skipping tag creation." +else + git tag "$TAG_NAME" -m "Release framework $VERSION" -m "$CHANGELOG_BODY" + git push origin "$TAG_NAME" +fi + +# Create GitHub release +TITLE="Framework $VERSION" + +# Mark prereleases when version contains a hyphen +PRERELEASE_FLAG="" +if [[ "$VERSION" == *-* ]]; then + PRERELEASE_FLAG="--prerelease" +fi + +LATEST_FLAG="" +if [[ "$VERSION" != *-* ]]; then + LATEST_FLAG="--latest" +fi + +BODY="## Framework Release $VERSION + +$CHANGELOG_BODY + +### Installation + +\`\`\`bash +go get github.com/maximhq/bifrost/framework@$VERSION +\`\`\` + +--- +_This release was automatically created and uses core version: \`$CORE_VERSION\`_" + +echo "πŸŽ‰ Creating GitHub release for $TITLE..." +if gh release view "$TAG_NAME" >/dev/null 2>&1; then + echo "ℹ️ Release $TAG_NAME already exists. Skipping creation." +else + gh release create "$TAG_NAME" \ + --title "$TITLE" \ + --notes "$BODY" \ + ${PRERELEASE_FLAG} ${LATEST_FLAG} + +fi + +echo "βœ… Framework released successfully" +echo "success=true" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/scripts/release-single-plugin.sh b/.github/workflows/scripts/release-single-plugin.sh new file mode 100755 index 000000000..e893da12b --- /dev/null +++ b/.github/workflows/scripts/release-single-plugin.sh @@ -0,0 +1,188 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Release a single plugin +# Usage: ./release-single-plugin.sh [core-version] [framework-version] + +# Source Go utilities for exponential backoff +source "$(dirname "$0")/go-utils.sh" +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [core-version] [framework-version]" + exit 1 +fi + +PLUGIN_NAME="$1" + +# Get core version from parameter or latest tag +if [ -n "${2:-}" ]; then + CORE_VERSION="$2" +else + # Get latest core version from git tags + LATEST_CORE_TAG=$(git tag -l "core/v*" | sort -V | tail -1) + if [ -z "$LATEST_CORE_TAG" ]; then + echo "❌ No core tags found, using version from file" + CORE_VERSION="v$(tr -d '\n\r' < core/version)" + else + CORE_VERSION=${LATEST_CORE_TAG#core/} + fi +fi + +# Get framework version from parameter or latest tag +if [ -n "${3:-}" ]; then + FRAMEWORK_VERSION="$3" +else + # Get latest framework version from git tags + LATEST_FRAMEWORK_TAG=$(git tag -l "framework/v*" | sort -V | tail -1) + if [ -z "$LATEST_FRAMEWORK_TAG" ]; then + echo "❌ No framework tags found, using version from file" + FRAMEWORK_VERSION="v$(tr -d '\n\r' < framework/version)" + else + FRAMEWORK_VERSION=${LATEST_FRAMEWORK_TAG#framework/} + fi +fi + +# Ensure we have the latest version +git pull origin + +echo "πŸ”Œ Releasing plugin: $PLUGIN_NAME" +echo "πŸ”§ Core version: $CORE_VERSION" +echo "πŸ”§ Framework version: $FRAMEWORK_VERSION" + +PLUGIN_DIR="plugins/$PLUGIN_NAME" +VERSION_FILE="$PLUGIN_DIR/version" + +if [ ! -f "$VERSION_FILE" ]; then + echo "❌ Version file not found: $VERSION_FILE" + exit 1 +fi + +PLUGIN_VERSION=$(tr -d '\n\r' < "$VERSION_FILE") +TAG_NAME="plugins/${PLUGIN_NAME}/v${PLUGIN_VERSION}" + +echo "πŸ“¦ Plugin version: $PLUGIN_VERSION" +echo "🏷️ Tag name: $TAG_NAME" + + +# Update plugin dependencies +echo "πŸ”§ Updating plugin dependencies..." +cd "$PLUGIN_DIR" + +# Update core dependency +if [ -f "go.mod" ]; then + go_get_with_backoff "github.com/maximhq/bifrost/core@${CORE_VERSION}" + go_get_with_backoff "github.com/maximhq/bifrost/framework@${FRAMEWORK_VERSION}" + go mod tidy + git add go.mod go.sum || true + + # Validate build + echo "πŸ”¨ Validating plugin build..." + go build ./... + + # Run tests if any exist + if go list ./... | grep -q .; then + echo "πŸ§ͺ Running plugin tests..." + go test ./... + fi + + echo "βœ… Plugin $PLUGIN_NAME build validation successful" +else + echo "ℹ️ No go.mod found, skipping Go dependency update" +fi + +cd ../.. + +# Commit and push changes if any +if ! git diff --cached --quiet; then + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + echo "πŸ”§ Committing and pushing changes..." + git commit -m "plugins/${PLUGIN_NAME}: bump core to $CORE_VERSION and framework to $FRAMEWORK_VERSION --skip-pipeline" + git push -u origin HEAD +else + echo "ℹ️ No staged changes to commit" +fi + +# Capturing changelog +CHANGELOG_BODY=$(cat $PLUGIN_DIR/changelog.md) +# Skip comments from changelog +CHANGELOG_BODY=$(echo "$CHANGELOG_BODY" | grep -v '^') +# If changelog is empty, return error +if [ -z "$CHANGELOG_BODY" ]; then + echo "❌ Changelog is empty" + exit 1 +fi +echo "πŸ“ New changelog: $CHANGELOG_BODY" + +# Finding previous tag +echo "πŸ” Finding previous tag..." +PREV_TAG=$(git tag -l "plugins/${PLUGIN_NAME}/v*" | sort -V | tail -1) +if [[ "$PREV_TAG" == "$TAG_NAME" ]]; then + PREV_TAG=$(git tag -l "plugins/${PLUGIN_NAME}/v*" | sort -V | tail -2 | head -1) +fi +echo "πŸ” Previous tag: $PREV_TAG" + +# Get message of the tag +echo "πŸ” Getting previous tag message..." +PREV_CHANGELOG=$(git tag -l --format='%(contents)' "$PREV_TAG") +echo "πŸ“ Previous changelog body: $PREV_CHANGELOG" + +# Checking if tag message is the same as the changelog +if [[ "$PREV_CHANGELOG" == "$CHANGELOG_BODY" ]]; then + echo "❌ Changelog is the same as the previous changelog" + exit 1 +fi + + +# Create and push tag +echo "🏷️ Creating tag: $TAG_NAME" + +if git rev-parse "$TAG_NAME" >/dev/null 2>&1; then + echo "ℹ️ Tag already exists: $TAG_NAME (skipping creation)" +else + git tag "$TAG_NAME" -m "Release plugin $PLUGIN_NAME v$PLUGIN_VERSION" -m "$CHANGELOG_BODY" + git push origin "$TAG_NAME" +fi + +# Create GitHub release +TITLE="Plugin $PLUGIN_NAME v$PLUGIN_VERSION" + +# Mark prereleases when version contains a hyphen +PRERELEASE_FLAG="" +if [[ "$PLUGIN_VERSION" == *-* ]]; then + PRERELEASE_FLAG="--prerelease" +fi + +# Mark as latest if not a prerelease +LATEST_FLAG="" +if [[ "$PLUGIN_VERSION" != *-* ]]; then + LATEST_FLAG="--latest" +fi + + +BODY="## Plugin Release: $PLUGIN_NAME v$PLUGIN_VERSION + +$CHANGELOG_BODY + +### Installation + +\`\`\`bash +# Update your go.mod to use the new plugin version +go get github.com/maximhq/bifrost/plugins/$PLUGIN_NAME@v$PLUGIN_VERSION +\`\`\` + +--- +_This release was automatically created from version file: \`plugins/$PLUGIN_NAME/version\`_" + +echo "πŸŽ‰ Creating GitHub release for $TITLE..." + +if gh release view "$TAG_NAME" >/dev/null 2>&1; then + echo "ℹ️ Release $TAG_NAME already exists. Skipping creation." +else + gh release create "$TAG_NAME" \ + --title "$TITLE" \ + --notes "$BODY" \ + ${PRERELEASE_FLAG} ${LATEST_FLAG} +fi + +echo "βœ… Plugin $PLUGIN_NAME released successfully" +echo "success=true" >> "${GITHUB_OUTPUT:-/dev/null}" diff --git a/.github/workflows/scripts/revert-latest.sh b/.github/workflows/scripts/revert-latest.sh new file mode 100755 index 000000000..0f6d3c33b --- /dev/null +++ b/.github/workflows/scripts/revert-latest.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Overwrite latest with a specific version from R2 +# Usage: ./revert-latest.sh + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 (e.g., v1.2.3)" + exit 1 +fi + +VERSION="$1" +# Ensure version starts with 'v' +if [[ ! "$VERSION" =~ ^v ]]; then + VERSION="v${VERSION}" +fi + +# Validate required environment variables +: "${R2_ENDPOINT:?R2_ENDPOINT env var is required}" +: "${R2_BUCKET:?R2_BUCKET env var is required}" + +# Clean endpoint URL +R2_ENDPOINT="$(echo "$R2_ENDPOINT" | tr -d '[:space:]')" + +echo "πŸ”„ Reverting latest to version: $VERSION" + +# Function to sync with retry logic +sync_with_retry() { + local source_path="$1" + local dest_path="$2" + local max_retries=3 + + for attempt in $(seq 1 $max_retries); do + echo "πŸ”„ Attempt $attempt/$max_retries: Syncing $source_path to $dest_path" + + if aws s3 sync "$source_path" "$dest_path" \ + --endpoint-url "$R2_ENDPOINT" \ + --profile "${R2_AWS_PROFILE:-R2}" \ + --no-progress \ + --delete; then + echo "βœ… Sync successful from $source_path to $dest_path" + return 0 + else + echo "⚠️ Attempt $attempt failed" + if [ $attempt -lt $max_retries ]; then + delay=$((2 ** attempt)) + echo "πŸ• Waiting ${delay}s before retry..." + sleep $delay + fi + fi + done + + echo "❌ All $max_retries attempts failed for syncing to $dest_path" + return 1 +} + +# Check if the version exists in R2 +echo "πŸ” Checking if version $VERSION exists..." +if ! aws s3 ls "s3://$R2_BUCKET/bifrost/$VERSION/" \ + --endpoint-url "$R2_ENDPOINT" \ + --profile "${R2_AWS_PROFILE:-R2}" >/dev/null 2>&1; then + echo "❌ Version $VERSION not found in R2 bucket" + echo "Available versions:" + aws s3 ls "s3://$R2_BUCKET/bifrost/" \ + --endpoint-url "$R2_ENDPOINT" \ + --profile "${R2_AWS_PROFILE:-R2}" | grep "PRE v" | awk '{print $2}' | sed 's/\///g' || true + exit 1 +fi + +echo "βœ… Version $VERSION found in R2" + +# Sync the specific version to latest +if ! sync_with_retry "s3://$R2_BUCKET/bifrost/$VERSION/" "s3://$R2_BUCKET/bifrost/latest/"; then + exit 1 +fi + +echo "πŸŽ‰ Successfully reverted latest to version $VERSION" diff --git a/.github/workflows/scripts/upload-to-r2.sh b/.github/workflows/scripts/upload-to-r2.sh new file mode 100755 index 000000000..b89fed3c8 --- /dev/null +++ b/.github/workflows/scripts/upload-to-r2.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Upload builds to R2 with retry logic +# Usage: ./upload-to-r2.sh + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 (e.g., transports/v1.2.3)" + exit 1 +fi +TRANSPORT_VERSION="$1" +if [[ ! -d "./dist" ]]; then + echo "❌ ./dist not found. Build artifacts must be present before upload." + exit 1 +fi +: "${R2_ENDPOINT:?R2_ENDPOINT env var is required}" +: "${R2_BUCKET:?R2_BUCKET env var is required}" + +# Strip 'transports/' prefix from version +VERSION_ONLY=${TRANSPORT_VERSION#transports/v} +CLI_VERSION="v${VERSION_ONLY}" +R2_ENDPOINT="$(echo "$R2_ENDPOINT" | tr -d '[:space:]')" + +echo "πŸ“€ Uploading binaries for version: $CLI_VERSION" + +# Function to upload with retry +upload_with_retry() { + local source_path="$1" + local dest_path="$2" + local max_retries=3 + + for attempt in $(seq 1 $max_retries); do + echo "πŸ”„ Attempt $attempt/$max_retries: Uploading to $dest_path" + + if aws s3 sync "$source_path" "$dest_path" \ + --endpoint-url "$R2_ENDPOINT" \ + --profile "${R2_AWS_PROFILE:-R2}" \ + --no-progress \ + --delete; then + echo "βœ… Upload successful to $dest_path" + return 0 + else + echo "⚠️ Attempt $attempt failed" + if [ $attempt -lt $max_retries ]; then + delay=$((2 ** attempt)) + echo "πŸ• Waiting ${delay}s before retry..." + sleep $delay + fi + fi + done + + echo "❌ All $max_retries attempts failed for $dest_path" + return 1 +} + +# Upload to versioned path +if ! upload_with_retry "./dist/" "s3://$R2_BUCKET/bifrost/$CLI_VERSION/"; then + exit 1 +fi + +# Check if this is a prerelease version (semver: presence of a hyphen denotes pre-release) +if [[ "$CLI_VERSION" == *-* ]]; then + echo "πŸ” Detected prerelease version: $CLI_VERSION" + echo "⏭️ Skipping upload to latest/ for prerelease" +else + echo "πŸ” Detected stable release: $CLI_VERSION" + + # Small delay between uploads (configurable; default 2s) + sleep "${INTER_UPLOAD_SLEEP_SECONDS:-2}" + + # Upload to latest path + echo "πŸ“€ Uploading to latest/" + if ! upload_with_retry "./dist/" "s3://$R2_BUCKET/bifrost/latest/"; then + exit 1 + fi +fi + +echo "πŸŽ‰ All binaries uploaded successfully to R2" diff --git a/.github/workflows/snyk.yml b/.github/workflows/snyk.yml new file mode 100644 index 000000000..db16d0aa9 --- /dev/null +++ b/.github/workflows/snyk.yml @@ -0,0 +1,103 @@ +name: Snyk checks + +on: + push: + branches: [main, master, '**/*'] + pull_request: + branches: ['**/*'] + workflow_dispatch: + +permissions: + contents: read + security-events: write + +jobs: + snyk-open-source: + name: Snyk Open Source (deps) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node (for UI) + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Setup Python (for tests tooling) + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Install Snyk CLI + uses: snyk/actions/setup@master + + - name: Snyk test (all projects) + env: + SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} + run: snyk test --all-projects --detection-depth=4 --sarif-file-output=snyk.sarif || true + + - name: Upload SARIF + if: always() + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: snyk.sarif + + snyk-code: + name: Snyk Code (SAST) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node (for UI) + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Setup Python (for tests tooling) + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Setup Python (for tests tooling) + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + cache-dependency-path: | + tests/integrations/requirements.txt + tests/governance/requirements.txt + + - name: Install Python dependencies (tests tooling) + run: | + python -m pip install --disable-pip-version-check \ + -r tests/integrations/requirements.txt \ + -r tests/governance/requirements.txt + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Build + run: make build + + - name: Install Snyk CLI + uses: snyk/actions/setup@master + + - name: Snyk Code test + env: + SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} + run: snyk code test --sarif-file-output=snyk-code.sarif || true + + - name: Upload SARIF + if: always() + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: snyk-code.sarif diff --git a/.gitignore b/.gitignore index 48303bc0a..3fa5840bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,31 @@ .env .vscode .DS_Store +*_creds* +**/venv/ +**/__pycache__/** +private.* +.venv +bifrost-data + +# Temporary directories +**/temp/ +/transports/ui +/transports/bifrost-http/lib/ui +/transports/bifrost-http/ui/ +transports/bifrost-http/logs/ +transports/bifrost-http/tmp/ +node_modules +/dist +**/tmp/ +temp*/ +tmp/ + +# Go workspaces (local only) +go.work +go.work.sum + +# Sqlite DBs +*.db +*.db-shm +*.db-wal \ No newline at end of file diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 000000000..4da40ee34 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,25 @@ +{ + "root": true, + "printWidth": 140, + "singleQuote": false, + "bracketSpacing": true, + "semi": true, + "bracketSameLine": false, + "useTabs": true, + "tabWidth": 2, + "trailingComma": "all", + "plugins": [ + "prettier-plugin-tailwindcss" + ], + "pluginSearchDirs": [ + "./ui" + ], + "tailwindAttributes": [ + "buttonClassname" + ], + "tailwindFunctions": [ + "cn", + "classNames" + ], + "endOfLine": "lf" +} \ No newline at end of file diff --git a/.snyk b/.snyk new file mode 100644 index 000000000..96a414bcc --- /dev/null +++ b/.snyk @@ -0,0 +1,5 @@ +# Snyk (https://snyk.io) policy file +# Manages vulnerability ignores and patches for this repository. +version: v1.25.0 +ignore: {} +patch: {} \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..182c6513e --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +akshay@getmaxim.ai. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..79ce1b7e8 --- /dev/null +++ b/Makefile @@ -0,0 +1,174 @@ +# Makefile for Bifrost + +# Variables +HOST ?= localhost +PORT ?= 8080 +APP_DIR ?= +PROMETHEUS_LABELS ?= +LOG_STYLE ?= json +LOG_LEVEL ?= info + +# Colors for output +RED=\033[0;31m +GREEN=\033[0;32m +YELLOW=\033[1;33m +BLUE=\033[0;34m +CYAN=\033[0;36m +NC=\033[0m # No Color + +.PHONY: all help dev build-ui build run install-air clean test install-ui setup-workspace work-init work-clean docs docker-build + +all: help + +# Default target +help: ## Show this help message + @echo "$(BLUE)Bifrost Development - Available Commands:$(NC)" + @echo "" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " $(GREEN)%-15s$(NC) %s\n", $$1, $$2}' + @echo "" + @echo "$(YELLOW)Environment Variables:$(NC)" + @echo " HOST Server host (default: localhost)" + @echo " PORT Server port (default: 8080)" + @echo " PROMETHEUS_LABELS Labels for Prometheus metrics" + @echo " LOG_STYLE Logger output format: json|pretty (default: json)" + @echo " LOG_LEVEL Logger level: debug|info|warn|error (default: info)" + @echo " APP_DIR App data directory inside container (default: /app/data)" + +install-ui: + @which node > /dev/null || (echo "$(RED)Error: Node.js is not installed. Please install Node.js first.$(NC)" && exit 1) + @which npm > /dev/null || (echo "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1) + @echo "$(GREEN)Node.js and npm are installed$(NC)" + @cd ui && npm install + @which next > /dev/null || (echo "$(YELLOW)Installing nextjs...$(NC)" && npm install -g next) + @echo "$(GREEN)UI deps are in sync$(NC)" + +install-air: ## Install air for hot reloading (if not already installed) + @which air > /dev/null || (echo "$(YELLOW)Installing air for hot reloading...$(NC)" && go install github.com/air-verse/air@latest) + @echo "$(GREEN)Air is ready$(NC)" + +dev: install-ui install-air setup-workspace ## Start complete development environment (UI + API with proxy) + @echo "$(GREEN)Starting Bifrost complete development environment...$(NC)" + @echo "$(YELLOW)This will start:$(NC)" + @echo " 1. UI development server (localhost:3000)" + @echo " 2. API server with UI proxy (localhost:$(PORT))" + @echo "$(CYAN)Access everything at: http://localhost:$(PORT)$(NC)" + @echo "" + @echo "$(YELLOW)Starting UI development server...$(NC)" + @cd ui && npm run dev & + @sleep 3 + @echo "$(YELLOW)Starting API server with UI proxy...$(NC)" + @$(MAKE) setup-workspace >/dev/null + @cd transports/bifrost-http && BIFROST_UI_DEV=true air -c .air.toml -- \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") \ + $(if $(APP_DIR),-app-dir "$(APP_DIR)") + +build-ui: install-ui ## Build ui + @echo "$(GREEN)Building ui...$(NC)" + @rm -rf ui/.next + @cd ui && npm run build && npm run copy-build + +build: build-ui ## Build bifrost-http binary + @echo "$(GREEN)Building bifrost-http...$(NC)" + @cd transports/bifrost-http && GOWORK=off go build -o ../../tmp/bifrost-http . + @echo "$(GREEN)Built: tmp/bifrost-http$(NC)" + +docker-build: build-ui ## Build Docker image + @echo "$(GREEN)Building Docker image...$(NC)" + @docker build -f transports/Dockerfile -t bifrost . + @echo "$(GREEN)Docker image built: bifrost$(NC)" + +docker-run: ## Run Docker container + @echo "$(GREEN)Running Docker container...$(NC)" + @docker run -e APP_PORT=$(PORT) -e APP_HOST=0.0.0.0 -p $(PORT):$(PORT) -e LOG_LEVEL=$(LOG_LEVEL) -e LOG_STYLE=$(LOG_STYLE) -v $(shell pwd):/app/data bifrost + +docs: ## Prepare local docs + @echo "$(GREEN)Preparing local docs...$(NC)" + @cd docs && npx --yes mintlify@latest dev + +run: build ## Build and run bifrost-http (no hot reload) + @echo "$(GREEN)Running bifrost-http...$(NC)" + @./tmp/bifrost-http \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") + $(if $(APP_DIR),-app-dir "$(APP_DIR)") + +clean: ## Clean build artifacts and temporary files + @echo "$(YELLOW)Cleaning build artifacts...$(NC)" + @rm -rf tmp/ + @rm -f transports/bifrost-http/build-errors.log + @rm -rf transports/bifrost-http/tmp/ + @echo "$(GREEN)Clean complete$(NC)" + +test: ## Run tests for bifrost-http + @echo "$(GREEN)Running bifrost-http tests...$(NC)" + @cd transports/bifrost-http && GOWORK=off go test -v ./... + +test-core: ## Run core tests + @echo "$(GREEN)Running core tests...$(NC)" + @cd core && go test -v ./... + +test-plugins: ## Run plugin tests + @echo "$(GREEN)Running plugin tests...$(NC)" + @cd plugins && find . -name "*.go" -path "*/tests/*" -o -name "*_test.go" | head -1 > /dev/null && \ + for dir in $$(find . -name "*_test.go" -exec dirname {} \; | sort -u); do \ + echo "Testing $$dir..."; \ + cd $$dir && go test -v ./... && cd - > /dev/null; \ + done || echo "No plugin tests found" + +test-all: test-core test-plugins test ## Run all tests + +# Quick start with example config +quick-start: ## Quick start with example config and maxim plugin + @echo "$(GREEN)Quick starting Bifrost with example configuration...$(NC)" + @$(MAKE) dev + +# Linting and formatting +lint: ## Run linter for Go code + @echo "$(GREEN)Running golangci-lint...$(NC)" + @golangci-lint run ./... + +fmt: ## Format Go code + @echo "$(GREEN)Formatting Go code...$(NC)" + @gofmt -s -w . + @goimports -w . + +# Workspace helpers +setup-workspace: ## Set up Go workspace with all local modules for development + @echo "$(GREEN)Setting up Go workspace for local development...$(NC)" + @echo "$(YELLOW)Cleaning existing workspace...$(NC)" + @rm -f go.work go.work.sum || true + @echo "$(YELLOW)Initializing new workspace...$(NC)" + @go work init ./core ./framework ./transports + @echo "$(YELLOW)Adding plugin modules...$(NC)" + @for plugin_dir in ./plugins/*/; do \ + if [ -d "$$plugin_dir" ] && [ -f "$$plugin_dir/go.mod" ]; then \ + echo " Adding plugin: $$(basename $$plugin_dir)"; \ + go work use "$$plugin_dir"; \ + fi; \ + done + @echo "$(YELLOW)Syncing workspace...$(NC)" + @go work sync + @echo "$(GREEN)βœ“ Go workspace ready with all local modules$(NC)" + @echo "" + @echo "$(CYAN)Local modules in workspace:$(NC)" + @go list -m all | grep "github.com/maximhq/bifrost" | grep -v " v" | sed 's/^/ βœ“ /' + @echo "" + @echo "$(CYAN)Remote modules (no local version):$(NC)" + @go list -m all | grep "github.com/maximhq/bifrost" | grep " v" | sed 's/^/ β†’ /' + @echo "" + @echo "$(YELLOW)Note: go.work files are not committed to version control$(NC)" + +work-init: ## Create local go.work to use local modules for development (legacy) + @echo "$(YELLOW)⚠️ work-init is deprecated, use 'make setup-workspace' instead$(NC)" + @$(MAKE) setup-workspace + +work-clean: ## Remove local go.work + @rm -f go.work go.work.sum || true + @echo "$(GREEN)Removed local go.work files$(NC)" diff --git a/README.md b/README.md index e227739c0..4cd888b0d 100644 --- a/README.md +++ b/README.md @@ -1,399 +1,264 @@ # Bifrost -Bifrost is an open-source middleware that serves as a unified gateway to various AI model providers, enabling seamless integration and fallback mechanisms for your AI-powered applications. - -## πŸ“‘ Table of Contents - -- [Bifrost](#bifrost) - - [πŸ“‘ Table of Contents](#-table-of-contents) - - [πŸ” Overview](#-overview) - - [✨ Features](#-features) - - [πŸ—οΈ Repository Structure](#️-repository-structure) - - [πŸ“Š Benchmarks](#-benchmarks) - - [Test Environment](#test-environment) - - [t3.medium Instance](#t3medium-instance) - - [t3.xlarge Instance](#t3xlarge-instance) - - [Performance Metrics](#performance-metrics) - - [Key Performance Highlights](#key-performance-highlights) - - [πŸš€ Getting Started](#-getting-started) - - [Package Structure](#package-structure) - - [Prerequisites](#prerequisites) - - [Setting up Bifrost](#setting-up-bifrost) - - [Additional Configurations](#additional-configurations) - - [🀝 Contributing](#-contributing) - - [πŸ“„ License](#-license) +[![Go Report Card](https://goreportcard.com/badge/github.com/maximhq/bifrost/core)](https://goreportcard.com/report/github.com/maximhq/bifrost/core) +[![Discord badge](https://dcbadge.limes.pink/api/server/https://discord.gg/exN5KAydbU?style=flat)](https://discord.gg/exN5KAydbU) +[![Known Vulnerabilities](https://snyk.io/test/github/maximhq/bifrost/badge.svg)](https://snyk.io/test/github/maximhq/bifrost) +[![codecov](https://codecov.io/gh/maximhq/bifrost/branch/main/graph/badge.svg)](https://codecov.io/gh/maximhq/bifrost) +![Docker Pulls](https://img.shields.io/docker/pulls/maximhq/bifrost) +[Run In Postman](https://app.getpostman.com/run-collection/31642484-2ba0e658-4dcd-49f4-845a-0c7ed745b916?action=collection%2Ffork&source=rip_markdown&collection-url=entityId%3D31642484-2ba0e658-4dcd-49f4-845a-0c7ed745b916%26entityType%3Dcollection%26workspaceId%3D63e853c8-9aec-477f-909c-7f02f543150e) +[![License](https://img.shields.io/github/license/maximhq/bifrost)](LICENSE) ---- +## The fastest way to build AI applications that never go down -## πŸ” Overview +Bifrost is a high-performance AI gateway that unifies access to 12+ providers (OpenAI, Anthropic, AWS Bedrock, Google Vertex, and more) through a single OpenAI-compatible API. Deploy in seconds with zero configuration and get automatic failover, load balancing, semantic caching, and enterprise-grade features. -Bifrost acts as a bridge between your applications and multiple AI providers (OpenAI, Anthropic, Amazon Bedrock, etc.). It provides a consistent API interface while handling: +## Quick Start -- Authentication and key management -- Request routing and load balancing -- Fallback mechanisms for reliability -- Unified request and response formatting -- Connection pooling and concurrency control +![Get started](./docs/media/getting-started.png) -With Bifrost, you can focus on building your AI-powered applications without worrying about the underlying provider-specific implementations. It handles all the complexities of key and provider management, providing a fixed input and output format so you don't need to modify your codebase for different providers. +**Go from zero to production-ready AI gateway in under a minute.** ---- +**Step 1:** Start Bifrost Gateway -## ✨ Features +```bash +# Install and run locally +npx -y @maximhq/bifrost -- **Multi-Provider Support**: Integrate with OpenAI, Anthropic, Amazon Bedrock, and more through a single API -- **Fallback Mechanisms**: Automatically retry failed requests with alternative models or providers -- **Dynamic Key Management**: Rotate and manage API keys efficiently -- **Connection Pooling**: Optimize network resources for better performance -- **Concurrency Control**: Manage rate limits and parallel requests effectively -- **HTTP Transport**: RESTful API interface for easy integration -- **Custom Configuration**: Flexible JSON-based configuration +# Or use Docker +docker run -p 8080:8080 maximhq/bifrost +``` ---- +**Step 2:** Configure via Web UI -## πŸ—οΈ Repository Structure +```bash +# Open the built-in web interface +open http://localhost:8080 +``` -Bifrost is built with a modular architecture: +**Step 3:** Make your first API call -``` -bifrost/ -β”œβ”€β”€ core/ # Core functionality and shared components -β”‚ β”œβ”€β”€ providers/ # Provider-specific implementations -β”‚ β”œβ”€β”€ schemas/ # Interfaces and structs used in bifrost -β”‚ β”œβ”€β”€ tests/ # Tests to make sure everything is in place -β”‚ β”œβ”€β”€ bifrost.go # Main Bifrost implementation -β”‚ -β”œβ”€β”€ transports/ # Interface layers (HTTP, gRPC, etc.) -β”‚ β”œβ”€β”€ http/ # HTTP transport implementation -β”‚ └── ... -β”‚ -└── plugins/ # Plugin Implementations - β”œβ”€β”€ maxim-logger.go - └── ... +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello, Bifrost!"}] + }' ``` -The system uses a provider-agnostic approach with well-defined interfaces to easily extend to new AI providers. All interfaces are defined in `core/schemas/` and can be used as a reference for adding new plugins. +**That's it!** Your AI gateway is running with a web interface for visual configuration, real-time monitoring, and analytics. + +**Complete Setup Guides:** + +- [Gateway Setup](https://docs.getbifrost.ai/quickstart/gateway/setting-up) - HTTP API deployment +- [Go SDK Setup](https://docs.getbifrost.ai/quickstart/go-sdk/setting-up) - Direct integration --- -## πŸ“Š Benchmarks - -Bifrost has been tested under high load conditions to ensure optimal performance. The following results were obtained from benchmark tests running at 5000 requests per second (RPS) on different AWS EC2 instances, with Bifrost running inside Docker containers. - -### Test Environment - -#### t3.medium Instance -- **Instance**: AWS EC2 t3.medium -- **vCPUs**: 2 -- **Memory**: 4GB RAM -- **Container**: Docker container with resource limits matching instance specs -- **Bifrost Configurations**: - - Buffer Size: 15,000 - - Initial Pool Size: 10,000 - -#### t3.xlarge Instance -- **Instance**: AWS EC2 t3.xlarge -- **vCPUs**: 4 -- **Memory**: 16GB RAM -- **Container**: Docker container with resource limits matching instance specs -- **Bifrost Configurations**: - - Buffer Size: 20,000 - - Initial Pool Size: 15,000 - -### Performance Metrics - -| Metric | t3.medium | t3.xlarge | -|--------|-----------|-----------| -| Success Rate | 100.00% | 100.00% | -| Average Request Size | 0.13 KB | 0.13 KB | -| **Average Response Size** | **`1.37 KB`** | **`10.32 KB`** | -| Average Latency | 2.12s | 1.61s | -| Peak Memory Usage | 1312.79 MB | 3340.44 MB | -| Queue Wait Time | 47.13 Β΅s | 1.67 Β΅s | -| Key Selection Time | 16 ns | 10 ns | -| Message Formatting | 2.19 Β΅s | 2.11 Β΅s | -| Params Preparation | 436 ns | 417 ns | -| Request Body Preparation | 2.65 Β΅s | 2.36 Β΅s | -| JSON Marshaling | 63.47 Β΅s | 26.80 Β΅s | -| Request Setup | 6.59 Β΅s | 7.17 Β΅s | -| HTTP Request | 1.56s | 1.50s | -| Error Handling | 189 ns | 162 ns | -| Response Parsing | 11.30 ms | 2.11 ms | - -### Key Performance Highlights - -- **Perfect Success Rate**: 100% request success rate under high load on both instances -- **Efficient Queue Management**: Minimal queue wait time (1.67 Β΅s on t3.xlarge) -- **Fast Key Selection**: Near-instantaneous key selection (10 ns on t3.xlarge) -- **Optimized Memory Usage**: - - t3.medium: ~1.3GB at 5000 RPS - - t3.xlarge: ~3.3GB at 5000 RPS -- **Efficient Request Processing**: Most operations complete in microseconds -- **Network Efficiency**: - - Consistent small request sizes (0.13 KB) across instances - - Larger response sizes on t3.xlarge (10.32 KB vs 1.37 KB) due to more detailed responses -- **Improved Performance on t3.xlarge**: - - 24% faster average latency - - 81% faster response parsing - - 58% faster JSON marshaling - - Significantly reduced queue wait times - - Higher buffer and pool sizes enabled by increased resources - -These benchmarks demonstrate Bifrost's ability to handle high-throughput scenarios while maintaining reliability and performance, even when containerized. The t3.xlarge instance shows improved performance across most metrics, particularly in processing times and latency, while maintaining the same high reliability and success rate. The larger response sizes on t3.xlarge indicate its ability to handle more detailed responses without compromising performance. - -One of Bifrost's key strengths is its flexibility in configuration. You can freely decide the tradeoff between memory usage and processing speed by adjusting Bifrost's configurations: - -- **Memory vs Speed Tradeoff**: - - Higher buffer and pool sizes (like in t3.xlarge) improve speed but use more memory - - Lower configurations (like in t3.medium) use less memory but may have slightly higher latencies - - You can fine-tune these parameters based on your specific needs and available resources - -- **Customizable Parameters**: - - Buffer Size: Controls the maximum number of concurrent requests - - Initial Pool Size: Determines the initial allocation of resources - - Concurrency Settings: Adjustable per provider - - Retry and Timeout Configurations: Customizable based on your requirements - -This flexibility allows you to optimize Bifrost for your specific use case, whether you prioritize speed, memory efficiency, or a balance between the two. +## Key Features + +### Core Infrastructure + +- **[Unified Interface](https://docs.getbifrost.ai/features/unified-interface)** - Single OpenAI-compatible API for all providers +- **[Multi-Provider Support](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration)** - OpenAI, Anthropic, AWS Bedrock, Google Vertex, Azure, Cohere, Mistral, Ollama, Groq, and more +- **[Automatic Fallbacks](https://docs.getbifrost.ai/features/fallbacks)** - Seamless failover between providers and models with zero downtime +- **[Load Balancing](https://docs.getbifrost.ai/features/fallbacks)** - Intelligent request distribution across multiple API keys and providers + +### Advanced Features + +- **[Model Context Protocol (MCP)](https://docs.getbifrost.ai/features/mcp)** - Enable AI models to use external tools (filesystem, web search, databases) +- **[Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching)** - Intelligent response caching based on semantic similarity to reduce costs and latency +- **[Multimodal Support](https://docs.getbifrost.ai/quickstart/gateway/streaming)** - Support for text,images, audio, and streaming, all behind a common interface. +- **[Custom Plugins](https://docs.getbifrost.ai/enterprise/custom-plugins)** - Extensible middleware architecture for analytics, monitoring, and custom logic +- **[Governance](https://docs.getbifrost.ai/features/governance)** - Usage tracking, rate limiting, and fine-grained access control + +### Enterprise & Security + +- **[Budget Management](https://docs.getbifrost.ai/features/governance)** - Hierarchical cost control with virtual keys, teams, and customer budgets +- **[SSO Integration](https://docs.getbifrost.ai/features/sso-with-google-github)** - Google and GitHub authentication support +- **[Observability](https://docs.getbifrost.ai/features/observability)** - Native Prometheus metrics, distributed tracing, and comprehensive logging +- **[Vault Support](https://docs.getbifrost.ai/enterprise/vault-support)** - Secure API key management with HashiCorp Vault integration + +### Developer Experience + +- **[Zero-Config Startup](https://docs.getbifrost.ai/quickstart/gateway/setting-up)** - Start immediately with dynamic provider configuration +- **[Drop-in Replacement](https://docs.getbifrost.ai/features/drop-in-replacement)** - Replace OpenAI/Anthropic/GenAI APIs with one line of code +- **[SDK Integrations](https://docs.getbifrost.ai/integrations/what-is-an-integration)** - Native support for popular AI SDKs with zero code changes +- **[Configuration Flexibility](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration)** - Web UI, API-driven, or file-based configuration options --- -## πŸš€ Getting Started - -If you want to **set up the Bifrost API quickly**, [check the transports documentation](https://github.com/maximhq/bifrost/tree/main/transports/README.md). - -### Package Structure - -Bifrost is divided into three Go packages: core, plugins, and transports. - -1. **core**: This package contains the core implementation of Bifrost as a Go package. - -2. **plugins**: This package serves as an extension to core. You can download this package using `go get github.com/maximhq/bifrost/plugins` and pass the plugins while initializing Bifrost. - - ```golang - plugin, err := plugins.NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) - if err != nil { - return nil, err - } - - // Initialize Bifrost - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &account, - Plugins: []schemas.Plugin{plugin}, - }) - ``` - -3. **transports**: This package contains transport clients like HTTP to expose your Bifrost client. You can either `go get` this package or directly use the independent Dockerfile to quickly spin up your Bifrost API interface ([Click here](https://github.com/maximhq/bifrost/tree/main/transports/README.md) to read more on this). - -### Prerequisites - -- Go 1.23 or higher -- Access to at least one AI model provider (OpenAI, Anthropic, etc.) -- API keys for the providers you wish to use - -### Setting up Bifrost - -1. Setting up your account: You first need to create your account which follows [Bifrost's account interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/account.go). - -Example: - ```golang - type BaseAccount struct{} - - func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{schemas.OpenAI}, nil - } - - func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - switch providerKey { - case schemas.OpenAI: - return []schemas.Key{ - { - Value: os.Getenv("OPENAI_API_KEY"), - Models: []string{"gpt-4o-mini"}, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } - } - - func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - switch providerKey { - case schemas.OpenAI: - return &schemas.ProviderConfig{ - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } - } - ``` - -Bifrost uses these methods to get all the keys and configurations it needs to call the providers. You can check the [Additional Configurations](#additional-configurations) section for further customizations. - -2. Get bifrost core package: Simply run `go get github.com/maximhq/bifrost/core` to download bifrost/core package. - -3. Initialising Bifrost: Initialise bifrost by providing your account implementation - -```golang -client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, -}) -``` +## Repository Structure -4. Make your First LLM Call! - -```golang - msg = "What is a LLM gateway?" - messages := []schemas.Message{ - { Role: schemas.RoleUser, Content: &msg }, - } - - bifrostResult, bifrostErr := bifrost.ChatCompletionRequest( - schemas.OpenAI, &schemas.BifrostRequest{ - Model: "gpt-4o", // make sure you have configured gpt-4o in your account interface - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - }, context.Background() - ) +Bifrost uses a modular architecture for maximum flexibility: + +```text +bifrost/ +β”œβ”€β”€ npx/ # NPX script for easy installation +β”œβ”€β”€ core/ # Core functionality and shared components +β”‚ β”œβ”€β”€ providers/ # Provider-specific implementations (OpenAI, Anthropic, etc.) +β”‚ β”œβ”€β”€ schemas/ # Interfaces and structs used throughout Bifrost +β”‚ └── bifrost.go # Main Bifrost implementation +β”œβ”€β”€ framework/ # Framework components for data persistence +β”‚ β”œβ”€β”€ configstore/ # Configuration storages +β”‚ β”œβ”€β”€ logstore/ # Request logging storages +β”‚ └── vectorstore/ # Vector storages +β”œβ”€β”€ transports/ # HTTP gateway and other interface layers +β”‚ └── bifrost-http/ # HTTP transport implementation +β”œβ”€β”€ ui/ # Web interface for HTTP gateway +β”œβ”€β”€ plugins/ # Extensible plugin system +β”‚ β”œβ”€β”€ governance/ # Budget management and access control +β”‚ β”œβ”€β”€ jsonparser/ # JSON parsing and manipulation utilities +β”‚ β”œβ”€β”€ logging/ # Request logging and analytics +β”‚ β”œβ”€β”€ maxim/ # Maxim's observability integration +β”‚ β”œβ”€β”€ mocker/ # Mock responses for testing and development +β”‚ β”œβ”€β”€ semanticcache/ # Intelligent response caching +β”‚ └── telemetry/ # Monitoring and observability +β”œβ”€β”€ docs/ # Documentation and guides +└── tests/ # Comprehensive test suites ``` -you can add model parameters by passing them in `Params:&schemas.ModelParameters{...yourParams}` ChatCompletionRequest. +--- + +## Getting Started Options + +Choose the deployment method that fits your needs: -### Additional Configurations +### 1. Gateway (HTTP API) -1. InitalPoolSize and DropExcessRequests: You can customise the initial pool size of the structs and channels bifrost creates on `bifrost.Init()`. A higher value would mean lesser run time allocations and lower latency but at the cost of more memory usage. Takes the defined default value if not provided. +**Best for:** Language-agnostic integration, microservices, and production deployments -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - InitialPoolSize: 500, - DropExcessRequests: true, - }) +```bash +# NPX - Get started in 30 seconds +npx -y @maximhq/bifrost + +# Docker - Production ready +docker run -p 8080:8080 -v $(pwd)/data:/app/data maximhq/bifrost ``` -When `DropExcessRequests` is set to true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. By default it is set to false. +**Features:** Web UI, real-time monitoring, multi-provider management, zero-config startup -2. Logger: Like account interface, bifrost also allows you to pass your custom logger if it follows [bifrost's logger interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/logger.go). Takes in the [default logger](https://github.com/maximhq/bifrost/blob/main/core/logger.go) if not provided. +**Learn More:** [Gateway Setup Guide](https://docs.getbifrost.ai/quickstart/gateway/setting-up) -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Logger: &yourLogger, - }) -``` +### 2. Go SDK -The default logger is set to level info by default. If you wish to use it but with a different log level, you can do it like this - +**Best for:** Direct Go integration with maximum performance and control -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), - }) +```bash +go get github.com/maximhq/bifrost/core ``` -3. Plugins: You can create and pass your custom pre-hook and post-hook plugins to bifrost as long as they follow [bifrost's plugin interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/plugin.go). +**Features:** Native Go APIs, embedded deployment, custom middleware integration -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Plugins: []schemas.Plugin{yourPlugin1, yourPlugin2, ...}, - }) -``` +**Learn More:** [Go SDK Guide](https://docs.getbifrost.ai/quickstart/go-sdk/setting-up) -4. Customise your provider settings: You can customise proxy config, timeouts, retry settings, concurrency buffer sizes for each of your provider in your account interface's GetConfigForProvider() method. - -exmaple: -```golang - schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 2, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.BedrockMetaConfig{ - SecretAccessKey: os.Getenv("BEDROCK_ACCESS_KEY"), - Region: StrPtr("us-east-1"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - ProxyConfig: &schemas.ProxyConfig{ - Type: schemas.HttpProxy, - URL: yourProxyURL, - }, - } -``` +### 3. Drop-in Replacement -You can manage buffer size (maximum number of requests you want to hold in the system) concurrency (maximum number of requests you want to be made concurrently) for each provider. You can manage user usage and provider limits by providing these custom provider settings Default values are taken for network config, concurrecy and buffer sizes if not provided. - -Bifrost also supports multiple API keys per provider, enabling both load balancing and redundancy. You can assign weights to each key to control how frequently they are selected for requests. By default, all keys are treated with equal weight unless specified otherwise. - -```golang - []schemas.Key{ - { - Value: os.Getenv("OPEN_AI_API_KEY1"), - Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, - Weight: 0.6, - }, - { - Value: os.Getenv("OPEN_AI_API_KEY2"), - Models: []string{"gpt-4-turbo"}, - Weight: 0.3, - }, - { - Value: os.Getenv("OPEN_AI_API_KEY3"), - Models: []string{"gpt-4o-mini"}, - Weight: 0.1, - }, - } -``` +**Best for:** Migrating existing applications with zero code changes -You can check [this](https://github.com/maximhq/bifrost/blob/main/core/tests/account.go) file to refer all the customisation settings. - -5. Fallbacks: You can define fallback providers for each request, which will be used if all retry attempts with your primary provider fail. These fallback providers are attempted in the order you specify, provided they are configured in your account at runtime. Once a fallback is triggered, its own retry settings will apply, rather than those of the original provider. - -```golang - result, err := bifrost.ChatCompletionRequest( - schemas.OpenAI, &schemas.BifrostRequest{ - Model: "gpt-4o", - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20240620", // make sure you have configured this - }, - }, - }, context.Background() - ) +```diff +# OpenAI SDK +- base_url = "https://api.openai.com" ++ base_url = "http://localhost:8080/openai" + +# Anthropic SDK +- base_url = "https://api.anthropic.com" ++ base_url = "http://localhost:8080/anthropic" + +# Google GenAI SDK +- api_endpoint = "https://generativelanguage.googleapis.com" ++ api_endpoint = "http://localhost:8080/genai" ``` +**Learn More:** [Integration Guides](https://docs.getbifrost.ai/integrations/what-is-an-integration) + +--- + +## Performance + +Bifrost adds virtually zero overhead to your AI requests. In sustained 5,000 RPS benchmarks, the gateway added only **11 Β΅s** of overhead per request. + +| Metric | t3.medium | t3.xlarge | Improvement | +|--------|-----------|-----------|-------------| +| Added latency (Bifrost overhead) | 59 Β΅s | **11 Β΅s** | **-81%** | +| Success rate @ 5k RPS | 100% | 100% | No failed requests | +| Avg. queue wait time | 47 Β΅s | **1.67 Β΅s** | **-96%** | +| Avg. request latency (incl. provider) | 2.12 s | **1.61 s** | **-24%** | + +**Key Performance Highlights:** + +- **Perfect Success Rate** - 100% request success rate even at 5k RPS +- **Minimal Overhead** - Less than 15 Β΅s additional latency per request +- **Efficient Queuing** - Sub-microsecond average wait times +- **Fast Key Selection** - ~10 ns to pick weighted API keys + +**Complete Benchmarks:** [Performance Analysis](https://docs.getbifrost.ai/benchmarking/getting-started) + +--- + +## Documentation + +**Complete Documentation:** [https://docs.getbifrost.ai](https://docs.getbifrost.ai) + +### Quick Start + +- [Gateway Setup](https://docs.getbifrost.ai/quickstart/gateway/setting-up) - HTTP API deployment in 30 seconds +- [Go SDK Setup](https://docs.getbifrost.ai/quickstart/go-sdk/setting-up) - Direct Go integration +- [Provider Configuration](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration) - Multi-provider setup + +### Features + +- [Multi-Provider Support](https://docs.getbifrost.ai/features/unified-interface) - Single API for all providers +- [MCP Integration](https://docs.getbifrost.ai/features/mcp) - External tool calling +- [Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching) - Intelligent response caching +- [Fallbacks & Load Balancing](https://docs.getbifrost.ai/features/fallbacks) - Reliability features +- [Budget Management](https://docs.getbifrost.ai/features/governance) - Cost control and governance + +### Integrations + +- [OpenAI SDK](https://docs.getbifrost.ai/integrations/openai-sdk) - Drop-in OpenAI replacement +- [Anthropic SDK](https://docs.getbifrost.ai/integrations/anthropic-sdk) - Drop-in Anthropic replacement +- [Google GenAI SDK](https://docs.getbifrost.ai/integrations/genai-sdk) - Drop-in GenAI replacement +- [LiteLLM SDK](https://docs.getbifrost.ai/integrations/litellm-sdk) - LiteLLM integration +- [Langchain SDK](https://docs.getbifrost.ai/integrations/langchain-sdk) - Langchain integration + +### Enterprise + +- [Custom Plugins](https://docs.getbifrost.ai/enterprise/custom-plugins) - Extend functionality +- [Clustering](https://docs.getbifrost.ai/enterprise/clustering) - Multi-node deployment +- [Vault Support](https://docs.getbifrost.ai/enterprise/vault-support) - Secure key management +- [Production Deployment](https://docs.getbifrost.ai/deployment/docker-setup) - Scaling and monitoring + --- -## 🀝 Contributing +## Need Help? -Contributions are welcome! We welcome all kinds of contributions β€” bug fixes, features, docs, and ideas. Please feel free to submit a Pull Request. +**[Join our Discord](https://discord.gg/exN5KAydbU)** for community support and discussions. -1. Fork the repository -2. Create your feature branch (`git checkout -b feature/amazing-feature`) -3. Commit your changes (`git commit -m 'Add some amazing feature'`) -4. Push to the branch (`git push origin feature/amazing-feature`) -5. Open a Pull Request and describe your changes +Get help with: + +- Quick setup assistance and troubleshooting +- Best practices and configuration tips +- Community discussions and support +- Real-time help with integrations --- -## πŸ“„ License +## Contributing + +We welcome contributions of all kinds! See our [Contributing Guide](https://docs.getbifrost.ai/contributing/setting-up-repo) for: -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. +- Setting up the development environment +- Code conventions and best practices +- How to submit pull requests +- Building and testing locally + +For development requirements and build instructions, see our [Development Setup Guide](https://docs.getbifrost.ai/contributing/building-a-plugins). --- +## License + +This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. + Built with ❀️ by [Maxim](https://github.com/maximhq) diff --git a/core/bifrost.go b/core/bifrost.go index a4ead1c27..e72d9c972 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -7,120 +7,86 @@ import ( "context" "fmt" "math/rand" - "os" - "os/signal" "slices" + "strings" "sync" - "syscall" + "sync/atomic" "time" "github.com/maximhq/bifrost/core/providers" schemas "github.com/maximhq/bifrost/core/schemas" ) -// RequestType represents the type of request being made to a provider. -type RequestType string - -const ( - TextCompletionRequest RequestType = "text_completion" - ChatCompletionRequest RequestType = "chat_completion" -) - // ChannelMessage represents a message passed through the request channel. // It contains the request, response and error channels, and the request type. type ChannelMessage struct { schemas.BifrostRequest - Response chan *schemas.BifrostResponse - Err chan schemas.BifrostError - Type RequestType + Context context.Context + Response chan *schemas.BifrostResponse + ResponseStream chan chan *schemas.BifrostStream + Err chan schemas.BifrostError + Type schemas.RequestType } -// Bifrost manages providers and maintains sepcified open channels for concurrent processing. +// Bifrost manages providers and maintains specified open channels for concurrent processing. // It handles request routing, provider management, and response processing. type Bifrost struct { - account schemas.Account // account interface - providers []schemas.Provider // list of processed providers - plugins []schemas.Plugin // list of plugins - requestQueues map[schemas.ModelProvider]chan ChannelMessage // provider request queues - waitGroups map[schemas.ModelProvider]*sync.WaitGroup // wait groups for each provider - channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init - responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init - errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init - logger schemas.Logger // logger instance, default logger is used if not provided - dropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. - backgroundCtx context.Context // Shared background context for nil context handling -} - -// createProviderFromProviderKey creates a new provider instance based on the provider key. -// It returns an error if the provider is not supported. -func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) { - switch providerKey { - case schemas.OpenAI: - return providers.NewOpenAIProvider(config, bifrost.logger), nil - case schemas.Anthropic: - return providers.NewAnthropicProvider(config, bifrost.logger), nil - case schemas.Bedrock: - return providers.NewBedrockProvider(config, bifrost.logger), nil - case schemas.Cohere: - return providers.NewCohereProvider(config, bifrost.logger), nil - case schemas.Azure: - return providers.NewAzureProvider(config, bifrost.logger), nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } + ctx context.Context + account schemas.Account // account interface + plugins []schemas.Plugin // list of plugins + requestQueues sync.Map // provider request queues (thread-safe) + waitGroups sync.Map // wait groups for each provider (thread-safe) + providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe) + channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init + responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init + errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init + responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init + pluginPipelinePool sync.Pool // Pool for PluginPipeline objects + logger schemas.Logger // logger instance, default logger is used if not provided + mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) + dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. } -// prepareProvider sets up a provider with its configuration, keys, and worker channels. -// It initializes the request queue and starts worker goroutines for processing requests. -func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error { - providerConfig, err := bifrost.account.GetConfigForProvider(providerKey) - if err != nil { - return fmt.Errorf("failed to get config for provider: %v", err) - } - - // Check if the provider has any keys - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil || len(keys) == 0 { - return fmt.Errorf("failed to get keys for provider: %v", err) - } - - queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider - - bifrost.requestQueues[providerKey] = queue - - // Start specified number of workers - bifrost.waitGroups[providerKey] = &sync.WaitGroup{} - - provider, err := bifrost.createProviderFromProviderKey(providerKey, config) - if err != nil { - return fmt.Errorf("failed to get provider for the given key: %v", err) - } +// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation. +type PluginPipeline struct { + plugins []schemas.Plugin + logger schemas.Logger - for range providerConfig.ConcurrencyAndBufferSize.Concurrency { - bifrost.waitGroups[providerKey].Add(1) - go bifrost.requestWorker(provider, queue) - } + // Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order) + executedPreHooks int + // Errors from PreHooks and PostHooks + preHookErrors []error + postHookErrors []error +} - return nil +// Define a set of retryable status codes +var retryableStatusCodes = map[int]bool{ + 500: true, // Internal Server Error + 502: true, // Bad Gateway + 503: true, // Service Unavailable + 504: true, // Gateway Timeout + 429: true, // Too Many Requests } +// INITIALIZATION + // Init initializes a new Bifrost instance with the given configuration. // It sets up the account, plugins, object pools, and initializes providers. // Returns an error if initialization fails. // Initial Memory Allocations happens here as per the initial pool size. -func Init(config schemas.BifrostConfig) (*Bifrost, error) { +func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { if config.Account == nil { return nil, fmt.Errorf("account is required to initialize Bifrost") } bifrost := &Bifrost{ - account: config.Account, - plugins: config.Plugins, - waitGroups: make(map[schemas.ModelProvider]*sync.WaitGroup), - requestQueues: make(map[schemas.ModelProvider]chan ChannelMessage), - dropExcessRequests: config.DropExcessRequests, - backgroundCtx: context.Background(), + ctx: ctx, + account: config.Account, + plugins: config.Plugins, + requestQueues: sync.Map{}, + waitGroups: sync.Map{}, } + bifrost.dropExcessRequests.Store(config.DropExcessRequests) // Initialize object pools bifrost.channelMessagePool = sync.Pool{ @@ -138,6 +104,19 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { return make(chan schemas.BifrostError, 1) }, } + bifrost.responseStreamPool = sync.Pool{ + New: func() interface{} { + return make(chan chan *schemas.BifrostStream, 1) + }, + } + bifrost.pluginPipelinePool = sync.Pool{ + New: func() interface{} { + return &PluginPipeline{ + preHookErrors: make([]error, 0), + postHookErrors: make([]error, 0), + } + }, + } // Prewarm pools with multiple objects for range config.InitialPoolSize { @@ -145,6 +124,11 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { bifrost.channelMessagePool.Put(&ChannelMessage{}) bifrost.responseChannelPool.Put(make(chan *schemas.BifrostResponse, 1)) bifrost.errorChannelPool.Put(make(chan schemas.BifrostError, 1)) + bifrost.responseStreamPool.Put(make(chan chan *schemas.BifrostStream, 1)) + bifrost.pluginPipelinePool.Put(&PluginPipeline{ + preHookErrors: make([]error, 0), + postHookErrors: make([]error, 0), + }) } providerKeys, err := bifrost.account.GetConfiguredProviders() @@ -157,618 +141,1437 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { } bifrost.logger = config.Logger + // Initialize MCP manager if configured + if config.MCPConfig != nil { + mcpManager, err := newMCPManager(ctx, *config.MCPConfig, bifrost.logger) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("failed to initialize MCP manager: %v", err)) + } else { + bifrost.mcpManager = mcpManager + bifrost.logger.Info("MCP integration initialized successfully") + } + } + // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { + if strings.TrimSpace(string(providerKey)) == "" { + bifrost.logger.Warn("provider key is empty, skipping init") + continue + } + config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { bifrost.logger.Warn(fmt.Sprintf("failed to get config for provider, skipping init: %v", err)) continue } - if err := bifrost.prepareProvider(providerKey, config); err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider: %v", err)) + // Lock the provider mutex during initialization + providerMutex := bifrost.getProviderMutex(providerKey) + providerMutex.Lock() + err = bifrost.prepareProvider(providerKey, config) + providerMutex.Unlock() + + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err)) } } return bifrost, nil } -// getChannelMessage gets a ChannelMessage from the pool and configures it with the request. -// It also gets response and error channels from their respective pools. -func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType RequestType) *ChannelMessage { - // Get channels from pool - responseChan := bifrost.responseChannelPool.Get().(chan *schemas.BifrostResponse) - errorChan := bifrost.errorChannelPool.Get().(chan schemas.BifrostError) +// PUBLIC API METHODS - // Clear any previous values to avoid leaking between requests - select { - case <-responseChan: - default: - } - select { - case <-errorChan: - default: +// TextCompletionRequest sends a text completion request to the specified provider. +func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.TextCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text not provided for text completion request", + }, + } } - // Get message from pool and configure it - msg := bifrost.channelMessagePool.Get().(*ChannelMessage) - msg.BifrostRequest = req - msg.Response = responseChan - msg.Err = errorChan - msg.Type = reqType - - return msg + return bifrost.handleRequest(ctx, req, schemas.TextCompletionRequest) } -// releaseChannelMessage returns a ChannelMessage and its channels to their respective pools. -func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { - // Put channels back in pools - bifrost.responseChannelPool.Put(msg.Response) - bifrost.errorChannelPool.Put(msg.Err) +// ChatCompletionRequest sends a chat completion request to the specified provider. +func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.ChatCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "chats not provided for chat completion request", + }, + } + } - // Clear references and return to pool - msg.Response = nil - msg.Err = nil - bifrost.channelMessagePool.Put(msg) + return bifrost.handleRequest(ctx, req, schemas.ChatCompletionRequest) } -// SelectKeyFromProviderForModel selects an appropriate API key for a given provider and model. -// It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil { - return "", err +// ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. +func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input.ChatCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "chats not provided for chat completion request", + }, + } } - if len(keys) == 0 { - return "", fmt.Errorf("no keys found for provider: %v", providerKey) - } + return bifrost.handleStreamRequest(ctx, req, schemas.ChatCompletionStreamRequest) +} - // filter out keys which dont support the model - var supportedKeys []schemas.Key - for _, key := range keys { - if slices.Contains(key.Models, model) { - supportedKeys = append(supportedKeys, key) +// EmbeddingRequest sends an embedding request to the specified provider. +func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.EmbeddingInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "embedding input not provided for embedding request", + }, } } - if len(supportedKeys) == 0 { - return "", fmt.Errorf("no keys found that support model: %s", model) - } + return bifrost.handleRequest(ctx, req, schemas.EmbeddingRequest) +} - if len(supportedKeys) == 1 { - return supportedKeys[0].Value, nil +// SpeechRequest sends a speech request to the specified provider. +func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.SpeechInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "speech input not provided for speech request", + }, + } } - // Use a weighted random selection based on key weights - totalWeight := 0 - for _, key := range supportedKeys { - totalWeight += int(key.Weight * 100) // Convert float to int for better performance + return bifrost.handleRequest(ctx, req, schemas.SpeechRequest) +} + +// SpeechStreamRequest sends a speech stream request to the specified provider. +func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input.SpeechInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "speech input not provided for speech stream request", + }, + } } - // Use a fast random number generator - randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) - randomValue := randomSource.Intn(totalWeight) + return bifrost.handleStreamRequest(ctx, req, schemas.SpeechStreamRequest) +} - // Select key based on weight - currentWeight := 0 - for _, key := range supportedKeys { - currentWeight += int(key.Weight * 100) - if randomValue < currentWeight { - return key.Value, nil +// TranscriptionRequest sends a transcription request to the specified provider. +func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.TranscriptionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "transcription input not provided for transcription request", + }, } } - // Fallback to first key if something goes wrong - return supportedKeys[0].Value, nil + return bifrost.handleRequest(ctx, req, schemas.TranscriptionRequest) } -// calculateBackoff implements exponential backoff with jitter for retry attempts. -func (bifrost *Bifrost) calculateBackoff(attempt int, config *schemas.ProviderConfig) time.Duration { - // Calculate an exponential backoff: initial * 2^attempt - backoff := config.NetworkConfig.RetryBackoffInitial * time.Duration(1< config.NetworkConfig.RetryBackoffMax { - backoff = config.NetworkConfig.RetryBackoffMax +// TranscriptionStreamRequest sends a transcription stream request to the specified provider. +func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input.TranscriptionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "transcription input not provided for transcription stream request", + }, + } } - // Add jitter (Β±20%) - jitter := float64(backoff) * (0.8 + 0.4*rand.Float64()) - - return time.Duration(jitter) + return bifrost.handleStreamRequest(ctx, req, schemas.TranscriptionStreamRequest) } -// requestWorker handles incoming requests from the queue for a specific provider. -// It manages retries, error handling, and response processing. -func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan ChannelMessage) { - defer bifrost.waitGroups[provider.GetProviderKey()].Done() - - for req := range queue { - var result *schemas.BifrostResponse - var bifrostError *schemas.BifrostError +// UpdateProviderConcurrency dynamically updates the queue size and concurrency for an existing provider. +// This method gracefully stops existing workers, creates a new queue with updated settings, +// and starts new workers with the updated concurrency configuration. +// +// Parameters: +// - providerKey: The provider to update +// +// Returns: +// - error: Any error that occurred during the update process +// +// Note: This operation will temporarily pause request processing for the specified provider +// while the transition occurs. In-flight requests will complete before workers are stopped. +// Buffered requests in the old queue will be transferred to the new queue to prevent loss. +func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvider) error { + bifrost.logger.Info(fmt.Sprintf("Updating concurrency configuration for provider %s", providerKey)) + + // Get the updated configuration from the account + providerConfig, err := bifrost.account.GetConfigForProvider(providerKey) + if err != nil { + return fmt.Errorf("failed to get updated config for provider %s: %v", providerKey, err) + } - key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err)) - req.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - Error: err, - }, - } - continue - } + // Lock the provider to prevent concurrent access during update + providerMutex := bifrost.getProviderMutex(providerKey) + providerMutex.Lock() + defer providerMutex.Unlock() + + // Check if provider currently exists + oldQueueValue, exists := bifrost.requestQueues.Load(providerKey) + if !exists { + bifrost.logger.Debug("provider %s not currently active, initializing with new configuration", providerKey) + // If provider doesn't exist, just prepare it with new configuration + return bifrost.prepareProvider(providerKey, providerConfig) + } - config, err := bifrost.account.GetConfigForProvider(provider.GetProviderKey()) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Error getting config for provider %s: %v", provider.GetProviderKey(), err)) - req.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - Error: err, - }, - } - continue - } + oldQueue := oldQueueValue.(chan ChannelMessage) - // Track attempts - var attempts int + bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey) - // Execute request with retries - for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { - if attempts > 0 { - // Log retry attempt - bifrost.logger.Info(fmt.Sprintf( - "Retrying request (attempt %d/%d) for model %s: %s", - attempts, config.NetworkConfig.MaxRetries, req.Model, - bifrostError.Error.Message, - )) + // Step 1: Create new queue with updated buffer size + newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) - // Calculate and apply backoff - backoff := bifrost.calculateBackoff(attempts-1, config) - time.Sleep(backoff) + // Step 2: Transfer any buffered requests from old queue to new queue + // This prevents request loss during the transition + transferredCount := 0 + var transferWaitGroup sync.WaitGroup + for { + select { + case msg := <-oldQueue: + select { + case newQueue <- msg: + transferredCount++ + default: + // New queue is full, handle this request in a goroutine + // This is unlikely with proper buffer sizing but provides safety + transferWaitGroup.Add(1) + go func(m ChannelMessage) { + defer transferWaitGroup.Done() + select { + case newQueue <- m: + // Message successfully transferred + case <-time.After(5 * time.Second): + bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout") + // Send error response to avoid hanging the client + select { + case m.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "request failed during provider concurrency update", + }, + }: + case <-time.After(1 * time.Second): + // If we can't send the error either, just log and continue + bifrost.logger.Warn("Failed to send error response during transfer timeout") + } + } + }(msg) + goto transferComplete } + default: + // No more buffered messages + goto transferComplete + } + } - bifrost.logger.Debug(fmt.Sprintf("Attempting request for provider %s", provider.GetProviderKey())) +transferComplete: + // Wait for all transfer goroutines to complete + transferWaitGroup.Wait() + if transferredCount > 0 { + bifrost.logger.Info("transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey) + } - // Attempt the request - if req.Type == TextCompletionRequest { - if req.Input.TextCompletionInput == nil { - bifrostError = &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text not provided for text completion request", - }, - } - break // Don't retry client errors - } else { - result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params) - } - } else if req.Type == ChatCompletionRequest { - if req.Input.ChatCompletionInput == nil { - bifrostError = &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "chats not provided for chat completion request", - }, - } - break // Don't retry client errors - } else { - result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params) - } - } + // Step 3: Close the old queue to signal workers to stop + close(oldQueue) - bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey())) + // Step 4: Atomically replace the queue + bifrost.requestQueues.Store(providerKey, newQueue) - // Check if successful or if we should retry - //TODO should have a better way to check for only network errors - if bifrostError == nil || bifrostError.IsBifrostError { // Only retry non-bifrost errors - break - } - } + // Step 5: Wait for all existing workers to finish processing in-flight requests + waitGroup, exists := bifrost.waitGroups.Load(providerKey) + if exists { + waitGroup.(*sync.WaitGroup).Wait() + bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) + } - if bifrostError != nil { - // Add retry information to error - if attempts > 0 { - bifrost.logger.Warn(fmt.Sprintf("Request failed after %d %s", - attempts, - map[bool]string{true: "retries", false: "retry"}[attempts > 1])) - } - req.Err <- *bifrostError - } else { - req.Response <- result - } + // Step 6: Create new wait group for the updated workers + bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{}) + + // Step 7: Create provider instance + provider, err := bifrost.createBaseProvider(providerKey, providerConfig) + if err != nil { + return fmt.Errorf("failed to create provider instance for %s: %v", providerKey, err) } - bifrost.logger.Debug(fmt.Sprintf("Worker for provider %s exiting...", provider.GetProviderKey())) -} + // Step 8: Start new workers with updated concurrency + bifrost.logger.Debug("starting %d new workers for provider %s with buffer size %d", + providerConfig.ConcurrencyAndBufferSize.Concurrency, + providerKey, + providerConfig.ConcurrencyAndBufferSize.BufferSize) -// GetConfiguredProviderFromProviderKey returns the provider instance for a given provider key. -// Uses the GetProviderKey method of the provider interface to find the provider. -func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key schemas.ModelProvider) (schemas.Provider, error) { - for _, provider := range bifrost.providers { - if provider.GetProviderKey() == key { - return provider, nil - } + for range providerConfig.ConcurrencyAndBufferSize.Concurrency { + waitGroupValue, _ := bifrost.waitGroups.Load(providerKey) + waitGroup := waitGroupValue.(*sync.WaitGroup) + waitGroup.Add(1) + go bifrost.requestWorker(provider, providerConfig, newQueue) } - return nil, fmt.Errorf("no provider found for key: %s", key) + bifrost.logger.Info("successfully updated concurrency configuration for provider %s", providerKey) + return nil } -// GetProviderQueue returns the request queue for a given provider key. -// If the queue doesn't exist, it creates one at runtime and initializes the provider, -// given the provider config is provided in the account interface implementation. -func (bifrost *Bifrost) GetProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) { - var queue chan ChannelMessage - var exists bool - - if queue, exists = bifrost.requestQueues[providerKey]; !exists { - bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey)) +// GetDropExcessRequests returns the current value of DropExcessRequests +func (bifrost *Bifrost) GetDropExcessRequests() bool { + return bifrost.dropExcessRequests.Load() +} - config, err := bifrost.account.GetConfigForProvider(providerKey) - if err != nil { - return nil, fmt.Errorf("failed to get config for provider: %v", err) - } +// UpdateDropExcessRequests updates the DropExcessRequests setting at runtime. +// This allows for hot-reloading of this configuration value. +func (bifrost *Bifrost) UpdateDropExcessRequests(value bool) { + bifrost.dropExcessRequests.Store(value) + bifrost.logger.Info("drop_excess_requests updated to: %v", value) +} - if err := bifrost.prepareProvider(providerKey, config); err != nil { - return nil, err - } +// getProviderMutex gets or creates a mutex for the given provider +func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *sync.RWMutex { + mutexValue, _ := bifrost.providerMutexes.LoadOrStore(providerKey, &sync.RWMutex{}) + return mutexValue.(*sync.RWMutex) +} - queue = bifrost.requestQueues[providerKey] +// MCP PUBLIC API + +// RegisterMCPTool registers a typed tool handler with the MCP integration. +// This allows developers to easily add custom tools that will be available +// to all LLM requests processed by this Bifrost instance. +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.Tool) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return queue, nil + return bifrost.mcpManager.registerTool(name, description, handler, toolSchema) } -// TextCompletionRequest sends a text completion request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. -func (bifrost *Bifrost) TextCompletionRequest(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req == nil { +// ExecuteMCPTool executes an MCP tool call and returns the result as a tool message. +// This is the main public API for manual MCP tool execution. +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.BifrostMessage: Tool message with execution result +// - schemas.BifrostError: Any execution error +func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, *schemas.BifrostError) { + if bifrost.mcpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: "bifrost request cannot be nil", + Message: "MCP is not configured in this Bifrost instance", }, } } - if req.Model == "" { + result, err := bifrost.mcpManager.executeTool(ctx, toolCall) + if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: "model is required", + Message: err.Error(), + Error: err, }, } } - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryTextCompletion(providerKey, req, ctx) - if primaryErr == nil { - return primaryResult, nil + return result, nil +} + +// IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) +// may temporarily increase latency for incoming requests while the operations are being processed. +// These operations involve network I/O and connection management that require mutex locks +// which can block briefly during execution. + +// GetMCPClients returns all MCP clients managed by the Bifrost instance. +// +// Returns: +// - []schemas.MCPClient: List of all MCP clients +// - error: Any retrieval error +func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { + if bifrost.mcpManager == nil { + return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - // If primary provider failed and we have fallbacks, try them in order - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } + clients, err := bifrost.mcpManager.GetClients() + if err != nil { + return nil, err + } - // Create a new request with the fallback model - fallbackReq := *req - fallbackReq.Model = fallback.Model + clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) + for _, client := range clients { + tools := make([]string, 0, len(client.ToolMap)) + for toolName := range client.ToolMap { + tools = append(tools, toolName) + } - // Try the fallback provider - result, fallbackErr := bifrost.tryTextCompletion(fallback.Provider, &fallbackReq, ctx) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + state := schemas.MCPConnectionStateConnected + if client.Conn == nil { + state = schemas.MCPConnectionStateDisconnected } + + clientsInConfig = append(clientsInConfig, schemas.MCPClient{ + Name: client.Name, + Config: client.ExecutionConfig, + Tools: tools, + State: state, + }) } - // All providers failed, return the original error - return nil, primaryErr + return clientsInConfig, nil } -// tryTextCompletion attempts a text completion request with a single provider. -// This is a helper function used by TextCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, +// AddMCPClient adds a new MCP client to the Bifrost instance. +// This allows for dynamic MCP client management at runtime. +// +// Parameters: +// - config: MCP client configuration +// +// Returns: +// - error: Any registration error +// +// Example: +// +// err := bifrost.AddMCPClient(schemas.MCPClientConfig{ +// Name: "my-mcp-client", +// ConnectionType: schemas.MCPConnectionTypeHTTP, +// ConnectionString: &url, +// }) +func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { + if bifrost.mcpManager == nil { + manager := &MCPManager{ + ctx: bifrost.ctx, + clientMap: make(map[string]*MCPClient), + logger: bifrost.logger, } + + bifrost.mcpManager = manager } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + return bifrost.mcpManager.AddClient(config) +} + +// RemoveMCPClient removes an MCP client from the Bifrost instance. +// This allows for dynamic MCP client management at runtime. +// +// Parameters: +// - name: Name of the client to remove +// +// Returns: +// - error: Any removal error +// +// Example: +// +// err := bifrost.RemoveMCPClient("my-mcp-client") +// if err != nil { +// log.Fatalf("Failed to remove MCP client: %v", err) +// } +func (bifrost *Bifrost) RemoveMCPClient(name string) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") } - if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, + return bifrost.mcpManager.RemoveClient(name) +} + +// EditMCPClientTools edits the tools of an MCP client. +// This allows for dynamic MCP client tool management at runtime. +// +// Parameters: +// - name: Name of the client to edit +// - toolsToAdd: Tools to add to the client +// - toolsToRemove: Tools to remove from the client +// +// Returns: +// - error: Any edit error +// +// Example: +// +// err := bifrost.EditMCPClientTools("my-mcp-client", []string{"tool1", "tool2"}, []string{"tool3"}) +// if err != nil { +// log.Fatalf("Failed to edit MCP client tools: %v", err) +// } +func (bifrost *Bifrost) EditMCPClientTools(name string, toolsToAdd []string, toolsToRemove []string) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + return bifrost.mcpManager.EditClientTools(name, toolsToAdd, toolsToRemove) +} + +// ReconnectMCPClient attempts to reconnect an MCP client if it is disconnected. +// +// Parameters: +// - name: Name of the client to reconnect +// +// Returns: +// - error: Any reconnection error +func (bifrost *Bifrost) ReconnectMCPClient(name string) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + return bifrost.mcpManager.ReconnectClient(name) +} + +// PROVIDER MANAGEMENT + +// createBaseProvider creates a provider based on the base provider type +func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) { + // Determine which provider type to create + targetProviderKey := providerKey + + if config.CustomProviderConfig != nil { + // Validate custom provider config + if config.CustomProviderConfig.BaseProviderType == "" { + return nil, fmt.Errorf("custom provider config missing base provider type") } + + // Validate that base provider type is supported + if !IsSupportedBaseProvider(config.CustomProviderConfig.BaseProviderType) { + return nil, fmt.Errorf("unsupported base provider type: %s", config.CustomProviderConfig.BaseProviderType) + } + + // Automatically set the custom provider key to the provider name + config.CustomProviderConfig.CustomProviderKey = string(providerKey) + + targetProviderKey = config.CustomProviderConfig.BaseProviderType + } + + switch targetProviderKey { + case schemas.OpenAI: + return providers.NewOpenAIProvider(config, bifrost.logger), nil + case schemas.Anthropic: + return providers.NewAnthropicProvider(config, bifrost.logger), nil + case schemas.Bedrock: + return providers.NewBedrockProvider(config, bifrost.logger) + case schemas.Cohere: + return providers.NewCohereProvider(config, bifrost.logger), nil + case schemas.Azure: + return providers.NewAzureProvider(config, bifrost.logger) + case schemas.Vertex: + return providers.NewVertexProvider(config, bifrost.logger) + case schemas.Mistral: + return providers.NewMistralProvider(config, bifrost.logger), nil + case schemas.Ollama: + return providers.NewOllamaProvider(config, bifrost.logger) + case schemas.Groq: + return providers.NewGroqProvider(config, bifrost.logger) + case schemas.SGL: + return providers.NewSGLProvider(config, bifrost.logger) + case schemas.Parasail: + return providers.NewParasailProvider(config, bifrost.logger) + case schemas.Cerebras: + return providers.NewCerebrasProvider(config, bifrost.logger) + case schemas.Gemini: + return providers.NewGeminiProvider(config, bifrost.logger), nil + case schemas.OpenRouter: + return providers.NewOpenRouterProvider(config, bifrost.logger), nil + default: + return nil, fmt.Errorf("unsupported provider: %s", targetProviderKey) + } +} + +// prepareProvider sets up a provider with its configuration, keys, and worker channels. +// It initializes the request queue and starts worker goroutines for processing requests. +// Note: This function assumes the caller has already acquired the appropriate mutex for the provider. +func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error { + providerConfig, err := bifrost.account.GetConfigForProvider(providerKey) + if err != nil { + return fmt.Errorf("failed to get config for provider: %v", err) + } + + queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider + + bifrost.requestQueues.Store(providerKey, queue) + + // Start specified number of workers + bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{}) + + provider, err := bifrost.createBaseProvider(providerKey, config) + if err != nil { + return fmt.Errorf("failed to create provider for the given key: %v", err) + } + + for range providerConfig.ConcurrencyAndBufferSize.Concurrency { + waitGroupValue, _ := bifrost.waitGroups.Load(providerKey) + waitGroup := waitGroupValue.(*sync.WaitGroup) + waitGroup.Add(1) + go bifrost.requestWorker(provider, providerConfig, queue) + } + + return nil +} + +// getProviderQueue returns the request queue for a given provider key. +// If the queue doesn't exist, it creates one at runtime and initializes the provider, +// given the provider config is provided in the account interface implementation. +// This function uses read locks to prevent race conditions during provider updates. +func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) { + // Use read lock to allow concurrent reads but prevent concurrent updates + providerMutex := bifrost.getProviderMutex(providerKey) + providerMutex.RLock() + + if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists { + queue := queueValue.(chan ChannelMessage) + providerMutex.RUnlock() + return queue, nil + } + + // Provider doesn't exist, need to create it + // Upgrade to write lock for creation + providerMutex.RUnlock() + providerMutex.Lock() + defer providerMutex.Unlock() + + // Double-check after acquiring write lock (another goroutine might have created it) + if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists { + queue := queueValue.(chan ChannelMessage) + return queue, nil + } + + bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey)) + + config, err := bifrost.account.GetConfigForProvider(providerKey) + if err != nil { + return nil, fmt.Errorf("failed to get config for provider: %v", err) + } + + if err := bifrost.prepareProvider(providerKey, config); err != nil { + return nil, err + } + + queueValue, _ := bifrost.requestQueues.Load(providerKey) + queue := queueValue.(chan ChannelMessage) + + return queue, nil +} + +// CORE INTERNAL LOGIC + +// shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately +func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool { + // If no primary error, we succeeded + if primaryErr == nil { + return false + } + + // Handle request cancellation + if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + primaryErr.Provider = req.Provider + return false + } + + // Check if this is a short-circuit error that doesn't allow fallbacks + // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) + if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + primaryErr.Provider = req.Provider + return false + } + + // If no fallbacks configured, return primary error + if len(req.Fallbacks) == 0 { + primaryErr.Provider = req.Provider + return false + } + + // Should proceed with fallbacks + return true +} + +// prepareFallbackRequest creates a fallback request and validates the provider config +// Returns the fallback request or nil if this fallback should be skipped +func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fallback schemas.Fallback) *schemas.BifrostRequest { + // Check if we have config for this fallback provider + _, err := bifrost.account.GetConfigForProvider(fallback.Provider) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) + return nil + } + + // Create a new request with the fallback provider and model + fallbackReq := *req + fallbackReq.Provider = fallback.Provider + fallbackReq.Model = fallback.Model + return &fallbackReq +} + +// shouldContinueWithFallbacks processes errors from fallback attempts +// Returns true if we should continue with more fallbacks, false if we should stop +func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, fallbackErr *schemas.BifrostError) bool { + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + fallbackErr.Provider = fallback.Provider + return false } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, TextCompletionRequest) + // Check if it was a short-circuit error that doesn't allow fallbacks + if fallbackErr.AllowFallbacks != nil && !*fallbackErr.AllowFallbacks { + fallbackErr.Provider = fallback.Provider + return false + } + + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + return true +} + +// handleRequest handles the request to the provider based on the request type +// It handles plugin hooks, request validation, response processing, and fallback providers. +// If the primary provider fails, it will try each fallback provider in order until one succeeds. +// It is the wrapper for all non-streaming public API methods. +func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := validateRequest(req); err != nil { + err.Provider = req.Provider + return nil, err + } + + // Handle nil context early to prevent blocking + if ctx == nil { + ctx = bifrost.ctx + } + + // Try the primary provider first + primaryResult, primaryErr := bifrost.tryRequest(req, ctx, requestType) + + // Check if we should proceed with fallbacks + shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) + if !shouldTryFallbacks { + return primaryResult, primaryErr + } + + // Try fallbacks in order + for _, fallback := range req.Fallbacks { + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) + if fallbackReq == nil { + continue + } + + // Try the fallback provider + result, fallbackErr := bifrost.tryRequest(fallbackReq, ctx, requestType) + if fallbackErr == nil { + bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + return result, nil + } + + // Check if we should continue with more fallbacks + if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { + return nil, fallbackErr + } + } + + primaryErr.Provider = req.Provider + // All providers failed, return the original error + return nil, primaryErr +} + +// handleStreamRequest handles the stream request to the provider based on the request type +// It handles plugin hooks, request validation, response processing, and fallback providers. +// If the primary provider fails, it will try each fallback provider in order until one succeeds. +// It is the wrapper for all streaming public API methods. +func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := validateRequest(req); err != nil { + err.Provider = req.Provider + return nil, err + } + + // Handle nil context early to prevent blocking + if ctx == nil { + ctx = bifrost.ctx + } + + // Try the primary provider first + primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, requestType) + + // Check if we should proceed with fallbacks + shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) + if !shouldTryFallbacks { + return primaryResult, primaryErr + } + + // Try fallbacks in order + for _, fallback := range req.Fallbacks { + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) + if fallbackReq == nil { + continue + } + + // Try the fallback provider + result, fallbackErr := bifrost.tryStreamRequest(fallbackReq, ctx, requestType) + if fallbackErr == nil { + bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + return result, nil + } + + // Check if we should continue with more fallbacks + if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { + return nil, fallbackErr + } + } + + primaryErr.Provider = req.Provider + // All providers failed, return the original error + return nil, primaryErr +} + +// tryRequest is a generic function that handles common request processing logic +// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling +func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context, requestType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { + queue, err := bifrost.getProviderQueue(req.Provider) + if err != nil { + return nil, newBifrostError(err) + } + + // Attach context keys to the context + ctx = attachContextKeys(ctx, req, requestType) + + // Add MCP tools to request if MCP is configured and requested + if requestType != schemas.EmbeddingRequest && + requestType != schemas.SpeechRequest && + requestType != schemas.TranscriptionRequest && + bifrost.mcpManager != nil { + req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + } + + pipeline := bifrost.getPluginPipeline() + defer bifrost.releasePluginPipeline(pipeline) + + preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + if shortCircuit != nil { + // Handle short-circuit with response (success case) + if shortCircuit.Response != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + // Handle short-circuit with error + if shortCircuit.Error != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + } + if preReq == nil { + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + } + + msg := bifrost.getChannelMessage(*preReq, requestType) + msg.Context = ctx - // Handle queue send with context and proper cleanup select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): - // Request was cancelled by caller bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: - if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so + if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, - } - } - // If not dropping excess requests, wait with context - if ctx == nil { - ctx = bifrost.backgroundCtx + return nil, newBifrostErrorFromMsg("request dropped: queue is full") } select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") } } - // Handle response var result *schemas.BifrostResponse + var resp *schemas.BifrostResponse select { case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) + resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) + if bifrostErr != nil { + bifrost.releaseChannelMessage(msg) + return nil, bifrostErr + } + bifrost.releaseChannelMessage(msg) + return resp, nil + case bifrostErrVal := <-msg.Err: + bifrostErrPtr := &bifrostErrVal + resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) + bifrost.releaseChannelMessage(msg) + if bifrostErrPtr != nil { + return nil, bifrostErrPtr + } + return resp, nil + } +} + +// tryStreamRequest is a generic function that handles common request processing logic +// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling +func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx context.Context, requestType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { + queue, err := bifrost.getProviderQueue(req.Provider) + if err != nil { + return nil, newBifrostError(err) + } + + // Attach context keys to the context + ctx = attachContextKeys(ctx, req, requestType) + + // Add MCP tools to request if MCP is configured and requested + if requestType != schemas.SpeechStreamRequest && requestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { + req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + } + + pipeline := bifrost.getPluginPipeline() + defer bifrost.releasePluginPipeline(pipeline) + + preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + if shortCircuit != nil { + // Handle short-circuit with response (success case) + if shortCircuit.Response != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return newBifrostMessageChan(resp), nil + } + // Handle short-circuit with stream + if shortCircuit.Stream != nil { + outputStream := make(chan *schemas.BifrostStream) + + // Create a post hook runner cause pipeline object is put back in the pool on defer + pipelinePostHookRunner := func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + return pipeline.RunPostHooks(ctx, result, err, preCount) + } + + go func() { + defer close(outputStream) + + for streamMsg := range shortCircuit.Stream { + if streamMsg == nil { + continue + } + + // Run post hooks on the stream message + processedResp, processedErr := pipelinePostHookRunner(&ctx, streamMsg.BifrostResponse, streamMsg.BifrostError) + + // Send the processed message to the output stream + outputStream <- &schemas.BifrostStream{ + BifrostResponse: processedResp, + BifrostError: processedErr, + } + } + }() + + return outputStream, nil + } + // Handle short-circuit with error + if shortCircuit.Error != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return newBifrostMessageChan(resp), nil + } + } + if preReq == nil { + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + } + + msg := bifrost.getChannelMessage(*preReq, requestType) + msg.Context = ctx + + select { + case queue <- *msg: + // Message was sent successfully + case <-ctx.Done(): + bifrost.releaseChannelMessage(msg) + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") + default: + if bifrost.dropExcessRequests.Load() { + bifrost.releaseChannelMessage(msg) + bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") + return nil, newBifrostErrorFromMsg("request dropped: queue is full") + } + select { + case queue <- *msg: + // Message was sent successfully + case <-ctx.Done(): + bifrost.releaseChannelMessage(msg) + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") + } + } + + select { + case stream := <-msg.ResponseStream: + bifrost.releaseChannelMessage(msg) + return stream, nil + case bifrostErrVal := <-msg.Err: + bifrost.releaseChannelMessage(msg) + return nil, &bifrostErrVal + } +} + +// requestWorker handles incoming requests from the queue for a specific provider. +// It manages retries, error handling, and response processing. +func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) { + defer func() { + if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok { + waitGroup := waitGroupValue.(*sync.WaitGroup) + waitGroup.Done() + } + }() + + for req := range queue { + var result *schemas.BifrostResponse + var stream chan *schemas.BifrostStream + var bifrostError *schemas.BifrostError + var err error + + // Determine the base provider type for key requirement checks + baseProvider := provider.GetProviderKey() + if cfg := config.CustomProviderConfig; cfg != nil && cfg.BaseProviderType != "" { + baseProvider = cfg.BaseProviderType + } + + key := schemas.Key{} + if providerRequiresKey(baseProvider) { + // Use the custom provider name for actual key selection, but pass base provider type for key validation + key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), req.Model, baseProvider) if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ + bifrost.logger.Warn("error selecting key for model %s: %v", req.Model, err) + req.Err <- schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ Message: err.Error(), + Error: err, }, } + continue + } + } + + // Track attempts + var attempts int + + // Create plugin pipeline for streaming requests outside retry loop to prevent leaks + var postHookRunner schemas.PostHookRunner + if IsStreamRequestType(req.Type) { + pipeline := bifrost.getPluginPipeline() + defer bifrost.releasePluginPipeline(pipeline) + + postHookRunner = func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostHooks(ctx, result, err, len(bifrost.plugins)) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + } + + // Execute request with retries + for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { + if attempts > 0 { + // Log retry attempt + bifrost.logger.Info("retrying request (attempt %d/%d) for model %s: %s", attempts, config.NetworkConfig.MaxRetries, req.Model, bifrostError.Error.Message) + + // Calculate and apply backoff + backoff := calculateBackoff(attempts-1, config) + time.Sleep(backoff) + } + + bifrost.logger.Debug("attempting request for provider %s", provider.GetProviderKey()) + + // Attempt the request + if IsStreamRequestType(req.Type) { + stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner, req.Type) + if bifrostError != nil && !bifrostError.IsBifrostError { + break // Don't retry client errors + } + } else { + result, bifrostError = handleProviderRequest(provider, &req, key, req.Type) + if bifrostError != nil { + break // Don't retry client errors + } + } + + bifrost.logger.Debug("request for provider %s completed", provider.GetProviderKey()) + + // Check if successful or if we should retry + if bifrostError == nil || + bifrostError.IsBifrostError || + (bifrostError.StatusCode != nil && !retryableStatusCodes[*bifrostError.StatusCode]) || + (bifrostError.Error.Type != nil && *bifrostError.Error.Type == schemas.RequestCancelled) { + break + } + } + + if bifrostError != nil { + // Add retry information to error + if attempts > 0 { + bifrost.logger.Warn("request failed after %d %s", attempts, map[bool]string{true: "retries", false: "retry"}[attempts > 1]) + } + // Send error with context awareness to prevent deadlock + select { + case req.Err <- *bifrostError: + // Error sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending error response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") + } + } else { + if IsStreamRequestType(req.Type) { + // Send stream with context awareness to prevent deadlock + select { + case req.ResponseStream <- stream: + // Stream sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending stream response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending stream response, client may have disconnected") + } + } else { + // Send response with context awareness to prevent deadlock + select { + case req.Response <- result: + // Response sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending response, client may have disconnected") + } } } - case err := <-msg.Err: - bifrost.releaseChannelMessage(msg) - return nil, &err } - // Return message to pool - bifrost.releaseChannelMessage(msg) - return result, nil + bifrost.logger.Debug("worker for provider %s exiting...", provider.GetProviderKey()) } -// ChatCompletionRequest sends a chat completion request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. -func (bifrost *Bifrost) ChatCompletionRequest(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req == nil { +// handleProviderRequest handles the request to the provider based on the request type +func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, reqType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { + switch reqType { + case schemas.TextCompletionRequest: + return provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params) + case schemas.ChatCompletionRequest: + return provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params) + case schemas.EmbeddingRequest: + return provider.Embedding(req.Context, req.Model, key, req.Input.EmbeddingInput, req.Params) + case schemas.SpeechRequest: + return provider.Speech(req.Context, req.Model, key, req.Input.SpeechInput, req.Params) + case schemas.TranscriptionRequest: + return provider.Transcription(req.Context, req.Model, key, req.Input.TranscriptionInput, req.Params) + default: return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: "bifrost request cannot be nil", + Message: fmt.Sprintf("unsupported request type: %s", reqType), }, } } +} - if req.Model == "" { +// handleProviderStreamRequest handles the stream request to the provider based on the request type +func handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, reqType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { + switch reqType { + case schemas.ChatCompletionStreamRequest: + return provider.ChatCompletionStream(req.Context, postHookRunner, req.Model, key, *req.Input.ChatCompletionInput, req.Params) + case schemas.SpeechStreamRequest: + return provider.SpeechStream(req.Context, postHookRunner, req.Model, key, req.Input.SpeechInput, req.Params) + case schemas.TranscriptionStreamRequest: + return provider.TranscriptionStream(req.Context, postHookRunner, req.Model, key, req.Input.TranscriptionInput, req.Params) + default: return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: "model is required", + Message: fmt.Sprintf("unsupported request type: %s", reqType), }, } } +} - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryChatCompletion(providerKey, req, ctx) - if primaryErr == nil { - return primaryResult, nil +// PLUGIN MANAGEMENT + +// RunPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit decision, and the count. +func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, int) { + var shortCircuit *schemas.PluginShortCircuit + var err error + for i, plugin := range p.plugins { + req, shortCircuit, err = plugin.PreHook(ctx, req) + if err != nil { + p.preHookErrors = append(p.preHookErrors, err) + p.logger.Warn("error in PreHook for plugin %s: %v", plugin.GetName(), err) + } + p.executedPreHooks = i + 1 + if shortCircuit != nil { + return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran + } } + return req, nil, p.executedPreHooks +} - // If primary provider failed and we have fallbacks, try them in order - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Skipping fallback provider %s: %v", fallback.Provider, err)) - continue - } +// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran. +// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). +// Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil. +func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, count int) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Defensive: ensure count is within valid bounds + if count < 0 { + count = 0 + } + if count > len(p.plugins) { + count = len(p.plugins) + } + var err error + for i := count - 1; i >= 0; i-- { + plugin := p.plugins[i] + resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr) + if err != nil { + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn("error in PostHook for plugin %s: %v", plugin.GetName(), err) + } + // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that + // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that + } + // Final logic: if both are set, error takes precedence, unless error is nil + if bifrostErr != nil { + if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error.Type == nil && + bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil { + // Defensive: treat as recovery if error is empty + return resp, nil + } + return resp, bifrostErr + } + return resp, nil +} - // Create a new request with the fallback model - fallbackReq := *req - fallbackReq.Model = fallback.Model +// resetPluginPipeline resets a PluginPipeline instance for reuse +func (p *PluginPipeline) resetPluginPipeline() { + p.executedPreHooks = 0 + p.preHookErrors = p.preHookErrors[:0] + p.postHookErrors = p.postHookErrors[:0] +} - // Try the fallback provider - result, fallbackErr := bifrost.tryChatCompletion(fallback.Provider, &fallbackReq, ctx) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %v", fallback.Provider, fallbackErr.Error.Message)) +// getPluginPipeline gets a PluginPipeline from the pool and configures it +func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { + pipeline := bifrost.pluginPipelinePool.Get().(*PluginPipeline) + pipeline.plugins = bifrost.plugins + pipeline.logger = bifrost.logger + pipeline.resetPluginPipeline() + return pipeline +} + +// releasePluginPipeline returns a PluginPipeline to the pool +func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { + pipeline.resetPluginPipeline() + bifrost.pluginPipelinePool.Put(pipeline) +} + +// POOL & RESOURCE MANAGEMENT + +// getChannelMessage gets a ChannelMessage from the pool and configures it with the request. +// It also gets response and error channels from their respective pools. +func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType schemas.RequestType) *ChannelMessage { + // Get channels from pool + responseChan := bifrost.responseChannelPool.Get().(chan *schemas.BifrostResponse) + errorChan := bifrost.errorChannelPool.Get().(chan schemas.BifrostError) + + // Clear any previous values to avoid leaking between requests + select { + case <-responseChan: + default: + } + select { + case <-errorChan: + default: + } + + // Get message from pool and configure it + msg := bifrost.channelMessagePool.Get().(*ChannelMessage) + msg.BifrostRequest = req + msg.Response = responseChan + msg.Err = errorChan + msg.Type = reqType + + // Conditionally allocate ResponseStream for streaming requests only + if IsStreamRequestType(reqType) { + responseStreamChan := bifrost.responseStreamPool.Get().(chan chan *schemas.BifrostStream) + // Clear any previous values to avoid leaking between requests + select { + case <-responseStreamChan: + default: } + msg.ResponseStream = responseStreamChan } - // All providers failed, return the original error - return nil, primaryErr + return msg } -// tryChatCompletion attempts a chat completion request with a single provider. -// This is a helper function used by ChatCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, +// releaseChannelMessage returns a ChannelMessage and its channels to their respective pools. +func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { + // Put channels back in pools + bifrost.responseChannelPool.Put(msg.Response) + bifrost.errorChannelPool.Put(msg.Err) + + // Return ResponseStream to pool if it was used + if msg.ResponseStream != nil { + // Drain any remaining channels to prevent memory leaks + select { + case <-msg.ResponseStream: + default: } + bifrost.responseStreamPool.Put(msg.ResponseStream) } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + // Clear references and return to pool + msg.Response = nil + msg.ResponseStream = nil + msg.Err = nil + bifrost.channelMessagePool.Put(msg) +} + +// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. +// It uses weighted random selection if multiple keys are available. +func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { + // Check if key has been set in the context explicitly + if ctx != nil { + key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) + if ok { + return key, nil } } - if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, - } + keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) + if err != nil { + return schemas.Key{}, err } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, ChatCompletionRequest) + if len(keys) == 0 { + return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey) + } - // Handle queue send with context and proper cleanup - select { - case queue <- *msg: - // Message was sent successfully - case <-ctx.Done(): - // Request was cancelled by caller - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } - default: - if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so - bifrost.releaseChannelMessage(msg) - bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, + // filter out keys which dont support the model, if the key has no models, it is supported for all models + var supportedKeys []schemas.Key + for _, key := range keys { + modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0 + + // Additional deployment checks for Azure and Bedrock + deploymentSupported := true + if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { + // For Azure, check if deployment exists for this model + if len(key.AzureKeyConfig.Deployments) > 0 { + _, deploymentSupported = key.AzureKeyConfig.Deployments[model] + } + } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { + // For Bedrock, check if deployment exists for this model + if len(key.BedrockKeyConfig.Deployments) > 0 { + _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] } } - // If not dropping excess requests, wait with context - if ctx == nil { - ctx = bifrost.backgroundCtx + + if modelSupported && deploymentSupported { + supportedKeys = append(supportedKeys, key) } - select { - case queue <- *msg: - // Message was sent successfully - case <-ctx.Done(): - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + } + + if len(supportedKeys) == 0 { + if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock { + return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) } + return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } - // Handle response - var result *schemas.BifrostResponse - select { - case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + if len(supportedKeys) == 1 { + return supportedKeys[0], nil + } + + // Use a weighted random selection based on key weights + totalWeight := 0 + for _, key := range supportedKeys { + totalWeight += int(key.Weight * 100) // Convert float to int for better performance + } + + // Use a fast random number generator + randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) + randomValue := randomSource.Intn(totalWeight) + + // Select key based on weight + currentWeight := 0 + for _, key := range supportedKeys { + currentWeight += int(key.Weight * 100) + if randomValue < currentWeight { + return key, nil } - case err := <-msg.Err: - bifrost.releaseChannelMessage(msg) - return nil, &err } - // Return message to pool - bifrost.releaseChannelMessage(msg) - return result, nil + // Fallback to first key if something goes wrong + return supportedKeys[0], nil } // Shutdown gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. func (bifrost *Bifrost) Shutdown() { - bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") + bifrost.logger.Info("closing all request channels...") // Close all provider queues to signal workers to stop - for _, queue := range bifrost.requestQueues { - close(queue) - } + bifrost.requestQueues.Range(func(key, value interface{}) bool { + close(value.(chan ChannelMessage)) + return true + }) // Wait for all workers to exit - for _, waitGroup := range bifrost.waitGroups { + bifrost.waitGroups.Range(func(key, value interface{}) bool { + waitGroup := value.(*sync.WaitGroup) waitGroup.Wait() - } -} + return true + }) -// Cleanup handles SIGINT (Ctrl+C) to exit cleanly. -// It sets up signal handling and calls Shutdown when interrupted. -func (bifrost *Bifrost) Cleanup() { - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) + // Cleanup MCP manager + if bifrost.mcpManager != nil { + err := bifrost.mcpManager.cleanup() + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP manager: %s", err.Error())) + } + } - <-signalChan // Wait for interrupt signal - bifrost.Shutdown() // Gracefully shut down + // Cleanup plugins + for _, plugin := range bifrost.plugins { + err := plugin.Cleanup() + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error cleaning up plugin: %s", err.Error())) + } + } } diff --git a/core/changelog.md b/core/changelog.md new file mode 100644 index 000000000..a0e9b0d60 --- /dev/null +++ b/core/changelog.md @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/core/go.mod b/core/go.mod index af649c745..94b20c261 100644 --- a/core/go.mod +++ b/core/go.mod @@ -1,33 +1,57 @@ module github.com/maximhq/bifrost/core -go 1.24.1 +go 1.24 -require github.com/joho/godotenv v1.5.1 +toolchain go1.24.3 require ( - github.com/aws/aws-sdk-go-v2 v1.36.3 - github.com/aws/aws-sdk-go-v2/config v1.29.14 - github.com/maximhq/bifrost/plugins v1.0.0 - github.com/valyala/fasthttp v1.60.0 + github.com/aws/aws-sdk-go-v2 v1.38.0 + github.com/aws/aws-sdk-go-v2/config v1.31.0 + github.com/bytedance/sonic v1.14.0 + github.com/mark3labs/mcp-go v0.37.0 + github.com/rs/zerolog v1.34.0 + github.com/valyala/fasthttp v1.65.0 + golang.org/x/oauth2 v0.30.0 ) require ( - github.com/andybalholm/brotli v1.1.1 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect - github.com/aws/smithy-go v1.22.3 // indirect - github.com/goccy/go-json v0.10.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/maximhq/maxim-go v0.1.1 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/stretchr/testify v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/text v0.24.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/core/go.sum b/core/go.sum index d0f8edd17..22df319aa 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,48 +1,126 @@ -github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= -github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= -github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= -github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= -github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= -github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= -github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= -github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/maximhq/bifrost/plugins v1.0.0 h1:ul4tMMQHOdhyFQueyZwmQB3uX+s2buYSKzq1FW0m090= -github.com/maximhq/bifrost/plugins v1.0.0/go.mod h1:IUDZ2NMgCjIn1SVCvYbWZd/Lsk96MNytOvEKpinjvHo= -github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= -github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= -github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/logger.go b/core/logger.go index b5d1bfcfa..c4c554d3a 100644 --- a/core/logger.go +++ b/core/logger.go @@ -2,11 +2,12 @@ package bifrost import ( - "fmt" "os" "time" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) // DefaultLogger implements the Logger interface with stdout/stderr printing. @@ -14,60 +15,103 @@ import ( // and error streams with formatted timestamps and log levels. // It is used as the default logger if no logger is provided in the BifrostConfig. type DefaultLogger struct { - level schemas.LogLevel // Current logging level + stderrLogger zerolog.Logger + stdoutLogger zerolog.Logger +} + +// toZerologLevel converts a Bifrost log level to a Zerolog level. +func toZerologLevel(l schemas.LogLevel) zerolog.Level { + switch l { + case schemas.LogLevelDebug: + return zerolog.DebugLevel + case schemas.LogLevelInfo: + return zerolog.InfoLevel + case schemas.LogLevelWarn: + return zerolog.WarnLevel + case schemas.LogLevelError: + return zerolog.ErrorLevel + default: + return zerolog.InfoLevel + } } // NewDefaultLogger creates a new DefaultLogger instance with the specified log level. // The log level determines which messages will be output based on their severity. func NewDefaultLogger(level schemas.LogLevel) *DefaultLogger { + zerolog.SetGlobalLevel(toZerologLevel(level)) + zerolog.DisableSampling(true) + zerolog.TimeFieldFormat = time.RFC3339 + log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger() return &DefaultLogger{ - level: level, - } -} - -// formatMessage formats the log message with timestamp, level, and optional error information. -// It creates a consistent log format: [BIFROST-TIMESTAMP] LEVEL: message (error: err) -func (logger *DefaultLogger) formatMessage(level schemas.LogLevel, msg string, err error) string { - timestamp := time.Now().Format(time.RFC3339) - baseMsg := fmt.Sprintf("[BIFROST-%s] %s: %s", timestamp, level, msg) - if err != nil { - return fmt.Sprintf("%s (error: %v)", baseMsg, err) + stderrLogger: zerolog.New(os.Stderr).With().Timestamp().Logger(), + stdoutLogger: zerolog.New(os.Stdout).With().Timestamp().Logger(), } - return baseMsg } // Debug logs a debug level message to stdout. // Messages are only output if the logger's level is set to LogLevelDebug. -func (logger *DefaultLogger) Debug(msg string) { - if logger.level == schemas.LogLevelDebug { - fmt.Fprintln(os.Stdout, logger.formatMessage(schemas.LogLevelDebug, msg, nil)) - } +func (logger *DefaultLogger) Debug(msg string, args ...any) { + logger.stdoutLogger.Debug().Msgf(msg, args...) } // Info logs an info level message to stdout. // Messages are output if the logger's level is LogLevelDebug or LogLevelInfo. -func (logger *DefaultLogger) Info(msg string) { - if logger.level == schemas.LogLevelDebug || logger.level == schemas.LogLevelInfo { - fmt.Fprintln(os.Stdout, logger.formatMessage(schemas.LogLevelInfo, msg, nil)) - } +func (logger *DefaultLogger) Info(msg string, args ...any) { + logger.stdoutLogger.Info().Msgf(msg, args...) } // Warn logs a warning level message to stdout. // Messages are output if the logger's level is LogLevelDebug, LogLevelInfo, or LogLevelWarn. -func (logger *DefaultLogger) Warn(msg string) { - if logger.level == schemas.LogLevelDebug || logger.level == schemas.LogLevelInfo || logger.level == schemas.LogLevelWarn { - fmt.Fprintln(os.Stdout, logger.formatMessage(schemas.LogLevelWarn, msg, nil)) - } +func (logger *DefaultLogger) Warn(msg string, args ...any) { + logger.stdoutLogger.Warn().Msgf(msg, args...) } // Error logs an error level message to stderr. // Error messages are always output regardless of the logger's level. -func (logger *DefaultLogger) Error(err error) { - fmt.Fprintln(os.Stderr, logger.formatMessage(schemas.LogLevelError, "", err)) +func (logger *DefaultLogger) Error(msg string, args ...any) { + logger.stderrLogger.Error().Msgf(msg, args...) +} + +// Fatal logs a fatal-level message to stderr. +// Fatal messages are always output regardless of the logger's level. +func (logger *DefaultLogger) Fatal(msg string, args ...any) { + // Check if any of the args is an error and exit with non-zero code if found + var errToPass error + for i, arg := range args { + if err, ok := arg.(error); ok && err != nil { + errToPass = err + // remove from args + args = append(args[:i], args[i+1:]...) + } + } + if errToPass != nil { + logger.stderrLogger.Fatal().Msgf(msg, errToPass) + } else { + logger.stderrLogger.Fatal().Msgf(msg, args...) + } } // SetLevel sets the logging level for the logger. // This determines which messages will be output based on their severity. func (logger *DefaultLogger) SetLevel(level schemas.LogLevel) { - logger.level = level + zerolog.SetGlobalLevel(toZerologLevel(level)) +} + +// SetOutputType sets the output type for the logger. +// This determines the format of the log output. +// If the output type is unknown, it defaults to JSON +func (logger *DefaultLogger) SetOutputType(outputType schemas.LoggerOutputType) { + switch outputType { + case schemas.LoggerOutputTypePretty: + logger.stdoutLogger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout}).With().Timestamp().Logger() + logger.stderrLogger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger() + case schemas.LoggerOutputTypeJSON: + logger.stdoutLogger = zerolog.New(os.Stdout).With().Timestamp().Logger() + logger.stderrLogger = zerolog.New(os.Stderr).With().Timestamp().Logger() + default: + logger.stderrLogger.Warn(). + Str("outputType", string(outputType)). + Msg("unknown logger output type; defaulting to JSON") + logger.stdoutLogger = zerolog.New(os.Stdout).With().Timestamp().Logger() + } } diff --git a/core/mcp.go b/core/mcp.go new file mode 100644 index 000000000..fc7b19d48 --- /dev/null +++ b/core/mcp.go @@ -0,0 +1,1126 @@ +package bifrost + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +const ( + // MCP defaults and identifiers + BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost + BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client + BifrostMCPClientKey = "bifrost-internal" // Key for internal Bifrost client in clientMap + MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix + MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment + + // Context keys for client filtering in requests + MCPContextKeyIncludeClients = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyExcludeClients = "mcp-exclude-clients" // Context key for blacklist client filtering + MCPContextKeyIncludeTools = "mcp-include-tools" // Context key for whitelist tool filtering + MCPContextKeyExcludeTools = "mcp-exclude-tools" // Context key for blacklist tool filtering +) + +// ============================================================================ +// TYPE DEFINITIONS +// ============================================================================ + +// MCPManager manages MCP integration for Bifrost core. +// It provides a bridge between Bifrost and various MCP servers, supporting +// both local tool hosting and external MCP server connections. +type MCPManager struct { + ctx context.Context + server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) + clientMap map[string]*MCPClient // Map of MCP client names to their configurations + mu sync.RWMutex // Read-write mutex for thread-safe operations + serverRunning bool // Track whether local MCP server is running + logger schemas.Logger // Logger instance for structured logging +} + +// MCPClient represents a connected MCP client with its configuration and tools. +type MCPClient struct { + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig schemas.MCPClientConfig // Tool filtering settings + ToolMap map[string]schemas.Tool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) +} + +// MCPClientConnectionInfo stores metadata about how a client is connected. +type MCPClientConnectionInfo struct { + Type schemas.MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess) + ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) + StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) +} + +// MCPToolHandler is a generic function type for handling tool calls with typed arguments. +// T represents the expected argument structure for the tool. +type MCPToolHandler[T any] func(args T) (string, error) + +// ============================================================================ +// CONSTRUCTOR AND INITIALIZATION +// ============================================================================ + +// newMCPManager creates and initializes a new MCP manager instance. +// +// Parameters: +// - config: MCP configuration including server port and client configs +// - logger: Logger instance for structured logging (uses default if nil) +// +// Returns: +// - *MCPManager: Initialized manager instance +// - error: Any initialization error +func newMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas.Logger) (*MCPManager, error) { + // Creating new instance + manager := &MCPManager{ + ctx: ctx, + clientMap: make(map[string]*MCPClient), + logger: logger, + } + // Process client configs: create client map entries and establish connections + for _, clientConfig := range config.ClientConfigs { + if err := manager.AddClient(clientConfig); err != nil { + manager.logger.Warn(fmt.Sprintf("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) + } + } + manager.logger.Info(MCPLogPrefix + " MCP Manager initialized") + return manager, nil +} + +// GetClients returns all MCP clients managed by the manager. +// +// Returns: +// - []*MCPClient: List of all MCP clients +// - error: Any retrieval error +func (m *MCPManager) GetClients() ([]MCPClient, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + clients := make([]MCPClient, 0, len(m.clientMap)) + for _, client := range m.clientMap { + clients = append(clients, *client) + } + + return clients, nil +} + +// ReconnectClient attempts to reconnect an MCP client if it is disconnected. +func (m *MCPManager) ReconnectClient(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, ok := m.clientMap[name] + if !ok { + return fmt.Errorf("client %s not found", name) + } + + if client.Conn != nil { + return fmt.Errorf("client %s is already connected", name) + } + + m.mu.Unlock() + + // connectToMCPClient handles locking internally + err := m.connectToMCPClient(client.ExecutionConfig) + if err != nil { + return fmt.Errorf("failed to connect to MCP client %s: %w", name, err) + } + + return nil +} + +// AddClient adds a new MCP client to the manager. +// It validates the client configuration and establishes a connection. +// +// Parameters: +// - config: MCP client configuration +// +// Returns: +func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(&config); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Make a copy of the config to use after unlocking + configCopy := config + + m.mu.Lock() + + if _, ok := m.clientMap[config.Name]; ok { + m.mu.Unlock() + return fmt.Errorf("client %s already exists", config.Name) + } + + // Create placeholder entry + m.clientMap[config.Name] = &MCPClient{ + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.Tool), + } + + // Temporarily unlock for the connection attempt + // This is to avoid deadlocks when the connection attempt is made + m.mu.Unlock() + + // Connect using the copied config + if err := m.connectToMCPClient(configCopy); err != nil { + // Re-lock to clean up the failed entry + m.mu.Lock() + delete(m.clientMap, config.Name) + m.mu.Unlock() + return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err) + } + + return nil +} + +// RemoveClient removes an MCP client from the manager. +// It handles cleanup for all transport types (HTTP, STDIO, SSE). +// +// Parameters: +// - name: Name of the client to remove +func (m *MCPManager) RemoveClient(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeClientUnsafe(name) +} + +func (m *MCPManager) removeClientUnsafe(name string) error { + client, ok := m.clientMap[name] + if !ok { + return fmt.Errorf("client %s not found", name) + } + + m.logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, name)) + + // Cancel SSE context if present (required for proper SSE cleanup) + if client.cancelFunc != nil { + client.cancelFunc() + client.cancelFunc = nil + } + + // Close the client transport connection + // This handles cleanup for all transport types (HTTP, STDIO, SSE) + if client.Conn != nil { + if err := client.Conn.Close(); err != nil { + m.logger.Error("%s Failed to close MCP client %s: %v", MCPLogPrefix, name, err) + } + client.Conn = nil + } + + // Clear client tool map + client.ToolMap = make(map[string]schemas.Tool) + + delete(m.clientMap, name) + return nil +} + +func (m *MCPManager) EditClientTools(name string, toolsToAdd []string, toolsToRemove []string) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, ok := m.clientMap[name] + if !ok { + return fmt.Errorf("client %s not found", name) + } + + if client.Conn == nil { + return fmt.Errorf("client %s has no active connection", name) + } + + // Update the client's execution config with new tool filters + config := client.ExecutionConfig + config.ToolsToExecute = toolsToAdd + config.ToolsToSkip = toolsToRemove + + // Store the updated config + client.ExecutionConfig = config + + // Clear current tool map + client.ToolMap = make(map[string]schemas.Tool) + + // Temporarily unlock for the network call + m.mu.Unlock() + + // Retrieve tools with updated configuration + tools, err := m.retrieveExternalTools(m.ctx, client.Conn, config) + + // Re-lock to update the tool map + m.mu.Lock() + + // Verify client still exists + if _, ok := m.clientMap[name]; !ok { + return fmt.Errorf("client %s was removed during tool update", name) + } + + if err != nil { + return fmt.Errorf("failed to retrieve external tools: %w", err) + } + + // Store discovered tools + maps.Copy(client.ToolMap, tools) + + return nil +} + +// ============================================================================ +// TOOL REGISTRATION AND DISCOVERY +// ============================================================================ + +// getAvailableTools returns all tools from connected MCP clients. +// Applies client filtering if specified in the context. +func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + var includeClients []string + var excludeClients []string + + // Extract client filtering from request context + if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + includeClients = existingIncludeClients + } + if existingExcludeClients, ok := ctx.Value(MCPContextKeyExcludeClients).([]string); ok && existingExcludeClients != nil { + excludeClients = existingExcludeClients + } + + tools := make([]schemas.Tool, 0) + for clientName, client := range m.clientMap { + // Apply client filtering logic + if !m.shouldIncludeClient(clientName, includeClients, excludeClients) { + continue + } + + // Add all tools from this client + for toolName, tool := range client.ToolMap { + if m.shouldSkipToolForRequest(toolName, ctx) { + continue + } + + tools = append(tools, tool) + } + } + return tools +} + +// registerTool registers a typed tool handler with the local MCP server. +// This is a convenience function that handles the conversion between typed Go +// handlers and the MCP protocol. +// +// Type Parameters: +// - T: The expected argument type for the tool (must be JSON-deserializable) +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Typed function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.Tool) error { + // Ensure local server is set up + if err := m.setupLocalHost(); err != nil { + return fmt.Errorf("failed to setup local host: %w", err) + } + + // Verify internal client exists + if _, ok := m.clientMap[BifrostMCPClientKey]; !ok { + return fmt.Errorf("bifrost client not found") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Check if tool name already exists to prevent silent overwrites + if _, exists := m.clientMap[BifrostMCPClientKey].ToolMap[name]; exists { + return fmt.Errorf("tool '%s' is already registered", name) + } + + m.logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + + // Create MCP handler wrapper that converts between typed and MCP interfaces + mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments from the request using the request's methods + args := request.GetArguments() + result, err := handler(args) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil + } + return mcp.NewToolResultText(result), nil + } + + // Register the tool with the local MCP server using AddTool + if m.server != nil { + tool := mcp.NewTool(name, mcp.WithDescription(description)) + m.server.AddTool(tool, mcpHandler) + } + + // Store tool definition for Bifrost integration + m.clientMap[BifrostMCPClientKey].ToolMap[name] = toolSchema + + return nil +} + +// setupLocalHost initializes the local MCP server and client if not already running. +// This creates a STDIO-based server for local tool hosting and a corresponding client. +// This is called automatically when tools are registered or when the server is needed. +// +// Returns: +// - error: Any setup error +func (m *MCPManager) setupLocalHost() error { + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + // Create and configure local MCP server (STDIO-based) + server, err := m.createLocalMCPServer() + if err != nil { + return fmt.Errorf("failed to create local MCP server: %w", err) + } + m.server = server + + // Create and configure local MCP client (STDIO-based) + client, err := m.createLocalMCPClient() + if err != nil { + return fmt.Errorf("failed to create local MCP client: %w", err) + } + m.clientMap[BifrostMCPClientKey] = client + + // Start the server and initialize client connection + return m.startLocalMCPServer() +} + +// createLocalMCPServer creates a new local MCP server instance with STDIO transport. +// This server will host tools registered via RegisterTool function. +// +// Returns: +// - *server.MCPServer: Configured MCP server instance +// - error: Any creation error +func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { + // Create MCP server + mcpServer := server.NewMCPServer( + "Bifrost-MCP-Server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + return mcpServer, nil +} + +// createLocalMCPClient creates a placeholder client entry for the local MCP server. +// The actual in-process client connection will be established in startLocalMCPServer. +// +// Returns: +// - *MCPClient: Placeholder client for local server +// - error: Any creation error +func (m *MCPManager) createLocalMCPClient() (*MCPClient, error) { + // Don't create the actual client connection here - it will be created + // after the server is ready using NewInProcessClient + return &MCPClient{ + Name: BifrostMCPClientName, + ExecutionConfig: schemas.MCPClientConfig{ + Name: BifrostMCPClientName, + }, + ToolMap: make(map[string]schemas.Tool), + ConnectionInfo: MCPClientConnectionInfo{ + Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport + }, + }, nil +} + +// startLocalMCPServer creates an in-process connection between the local server and client. +// +// Returns: +// - error: Any startup error +func (m *MCPManager) startLocalMCPServer() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + if m.server == nil { + return fmt.Errorf("server not initialized") + } + + // Create in-process client directly connected to the server + inProcessClient, err := client.NewInProcessClient(m.server) + if err != nil { + return fmt.Errorf("failed to create in-process MCP client: %w", err) + } + + // Update the client connection + clientEntry, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + clientEntry.Conn = inProcessClient + + // Initialize the in-process client + ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Create proper initialize request with correct structure + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: BifrostMCPClientName, + Version: BifrostMCPVersion, + }, + }, + } + + _, err = inProcessClient.Initialize(ctx, initRequest) + if err != nil { + return fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // Mark server as running + m.serverRunning = true + + return nil +} + +// executeTool executes a tool call and returns the result as a tool message. +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.BifrostMessage: Tool message with execution result +// - error: Any execution error +func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + toolName := *toolCall.Function.Name + + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) + } + + // Find which client has this tool + client := m.findMCPClientForTool(toolName) + if client == nil { + return nil, fmt.Errorf("tool '%s' not found in any connected MCP client", toolName) + } + + if client.Conn == nil { + return nil, fmt.Errorf("client '%s' has no active connection", client.Name) + } + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + } + + m.logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.Name)) + + toolResponse, callErr := client.Conn.CallTool(ctx, callRequest) + if callErr != nil { + m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.Name, callErr) + return nil, fmt.Errorf("MCP tool call failed: %v", callErr) + } + + m.logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) + + // Extract text from MCP response + responseText := m.extractTextFromMCPResponse(toolResponse, toolName) + + // Create tool response message + return m.createToolResponseMessage(toolCall, responseText), nil +} + +// ============================================================================ +// EXTERNAL MCP CONNECTION MANAGEMENT +// ============================================================================ + +// connectToMCPClient establishes a connection to an external MCP server and +// registers its available tools with the manager. +func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { + // First lock: Initialize or validate client entry + m.mu.Lock() + + // Initialize or validate client entry + if existingClient, exists := m.clientMap[config.Name]; exists { + // Client entry exists from config, check for existing connection + if existingClient.Conn != nil { + m.mu.Unlock() + return fmt.Errorf("client %s already has an active connection", config.Name) + } + // Update connection type for this connection attempt + existingClient.ConnectionInfo.Type = config.ConnectionType + } else { + // Create new client entry with configuration + m.clientMap[config.Name] = &MCPClient{ + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.Tool), + ConnectionInfo: MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, + } + } + m.mu.Unlock() + + // Heavy operations performed outside lock + var externalClient *client.Client + var connectionInfo MCPClientConnectionInfo + var err error + + // Create appropriate transport based on connection type + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + externalClient, connectionInfo, err = m.createHTTPConnection(config) + case schemas.MCPConnectionTypeSTDIO: + externalClient, connectionInfo, err = m.createSTDIOConnection(config) + case schemas.MCPConnectionTypeSSE: + externalClient, connectionInfo, err = m.createSSEConnection(config) + case schemas.MCPConnectionTypeInProcess: + externalClient, connectionInfo, err = m.createInProcessConnection(config) + default: + return fmt.Errorf("unknown connection type: %s", config.ConnectionType) + } + + if err != nil { + return fmt.Errorf("failed to create connection: %w", err) + } + + // Initialize the external client with timeout + // For SSE connections, we need a long-lived context, for others we can use timeout + var ctx context.Context + var cancel context.CancelFunc + + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + // SSE connections need a long-lived context for the persistent stream + ctx, cancel = context.WithCancel(m.ctx) + // Don't defer cancel here - SSE needs the context to remain active + } else { + // Other connection types can use timeout context + ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + } + + // Start the transport first (required for STDIO and SSE clients) + if err := externalClient.Start(ctx); err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) + } + + // Create proper initialize request for external client + extInitRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s", config.Name), + Version: "1.0.0", + }, + }, + } + + _, err = externalClient.Initialize(ctx, extInitRequest) + if err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) + } + + // Retrieve tools from the external server (this also requires network I/O) + tools, err := m.retrieveExternalTools(ctx, externalClient, config) + if err != nil { + m.logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) + // Continue with connection even if tool retrieval fails + tools = make(map[string]schemas.Tool) + } + + // Second lock: Update client with final connection details and tools + m.mu.Lock() + defer m.mu.Unlock() + + // Verify client still exists (could have been cleaned up during heavy operations) + if client, exists := m.clientMap[config.Name]; exists { + // Store the external client connection and details + client.Conn = externalClient + client.ConnectionInfo = connectionInfo + + // Store cancel function for SSE connections to enable proper cleanup + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + client.cancelFunc = cancel + } + + // Store discovered tools + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + + m.logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name)) + } else { + return fmt.Errorf("client %s was removed during connection setup", config.Name) + } + + return nil +} + +// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. +func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.Tool, error) { + // Get available tools from external server + listRequest := mcp.ListToolsRequest{ + PaginatedRequest: mcp.PaginatedRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsList), + }, + }, + } + + toolsResponse, err := client.ListTools(ctx, listRequest) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %v", err) + } + + if toolsResponse == nil { + return make(map[string]schemas.Tool), nil // No tools available + } + + tools := make(map[string]schemas.Tool) + + // toolsResponse is already a ListToolsResult + for _, mcpTool := range toolsResponse.Tools { + // Check if tool should be skipped based on configuration + if m.shouldSkipToolForConfig(mcpTool.Name, config) { + continue + } + + // Convert MCP tool schema to Bifrost format + bifrostTool := m.convertMCPToolToBifrostSchema(&mcpTool) + tools[mcpTool.Name] = bifrostTool + } + + return tools, nil +} + +// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). +func (m *MCPManager) shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { + // If ToolsToExecute is specified, only execute tools in that list + if len(config.ToolsToExecute) > 0 { + for _, allowedTool := range config.ToolsToExecute { + if allowedTool == toolName { + return false // Tool is allowed + } + } + return true // Tool not in allowed list + } + + // Check if tool is in skip list + for _, skipTool := range config.ToolsToSkip { + if skipTool == toolName { + return true // Tool should be skipped + } + } + + return false // Tool is allowed +} + +// shouldSkipToolForRequest checks if a tool should be skipped based on the request context. +func (m *MCPManager) shouldSkipToolForRequest(toolName string, ctx context.Context) bool { + includeTools := ctx.Value(MCPContextKeyIncludeTools) + excludeTools := ctx.Value(MCPContextKeyExcludeTools) + + if includeTools != nil { + if includeStr, ok := includeTools.(string); ok && includeStr != "" { + includeToolsList := strings.Split(includeStr, ",") + if slices.Contains(includeToolsList, toolName) { + return false // Tool is allowed + } + } + } + + if excludeTools != nil { + if excludeStr, ok := excludeTools.(string); ok && excludeStr != "" { + excludeToolsList := strings.Split(excludeStr, ",") + if slices.Contains(excludeToolsList, toolName) { + return true // Tool should be skipped + } + } + } + + return false // Tool is allowed +} + +// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. +func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.Tool { + return schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: mcpTool.Name, + Description: mcpTool.Description, + Parameters: schemas.FunctionParameters{ + Type: mcpTool.InputSchema.Type, + Properties: mcpTool.InputSchema.Properties, + Required: mcpTool.InputSchema.Required, + }, + }, + } +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func (m *MCPManager) extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(content.Text) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, responseText string) *schemas.BifrostMessage { + return &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: &responseText, + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + mcpTools := m.getAvailableTools(ctx) + if len(mcpTools) > 0 { + // Initialize tools array if needed + if req.Params == nil { + req.Params = &schemas.ModelParameters{} + } + if req.Params.Tools == nil { + req.Params.Tools = &[]schemas.Tool{} + } + tools := *req.Params.Tools + + // Create a map of existing tool names for O(1) lookup + existingToolsMap := make(map[string]bool) + for _, tool := range tools { + existingToolsMap[tool.Function.Name] = true + } + + // Add MCP tools that are not already present + for _, mcpTool := range mcpTools { + if !existingToolsMap[mcpTool.Function.Name] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[mcpTool.Function.Name] = true + } + } + req.Params.Tools = &tools + + } + return req +} + +func validateMCPClientConfig(config *schemas.MCPClientConfig) error { + if strings.TrimSpace(config.Name) == "" { + return fmt.Errorf("name is required for MCP client config") + } + + if config.ConnectionType == "" { + return fmt.Errorf("connection type is required for MCP client config") + } + + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSSE: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSTDIO: + if config.StdioConfig == nil { + return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeInProcess: + // InProcess requires a server instance to be provided programmatically + // This cannot be validated from JSON config - the server must be set when using the Go package + if config.InProcessServer == nil { + return fmt.Errorf("InProcessServer is required for InProcess connection type in client '%s' (Go package only)", config.Name) + } + default: + return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) + } + + // Check for overlapping tools between ToolsToSkip and ToolsToExecute + if len(config.ToolsToSkip) > 0 && len(config.ToolsToExecute) > 0 { + skipMap := make(map[string]bool) + for _, tool := range config.ToolsToSkip { + skipMap[tool] = true + } + + var overlapping []string + for _, tool := range config.ToolsToExecute { + if skipMap[tool] { + overlapping = append(overlapping, tool) + } + } + + if len(overlapping) > 0 { + return fmt.Errorf("tools cannot be both included and excluded in client '%s': %v", config.Name, overlapping) + } + } + + return nil +} + +// ============================================================================ +// HELPER METHODS +// ============================================================================ + +// findMCPClientForTool safely finds a client that has the specified tool. +func (m *MCPManager) findMCPClientForTool(toolName string) *MCPClient { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, client := range m.clientMap { + if _, exists := client.ToolMap[toolName]; exists { + return client + } + } + return nil +} + +// shouldIncludeClient determines if a client should be included based on filtering rules. +func (m *MCPManager) shouldIncludeClient(clientName string, includeClients, excludeClients []string) bool { + // If includeClients is specified, only include those clients (whitelist mode) + if len(includeClients) > 0 { + return slices.Contains(includeClients, clientName) + } + + // If excludeClients is specified, exclude those clients (blacklist mode) + if len(excludeClients) > 0 { + return !slices.Contains(excludeClients, clientName) + } + + // Default: include all clients + return true +} + +// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. +func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + } + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, + } + + // Create StreamableHTTP transport + httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString) + if err != nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + client := client.NewClient(httpTransport) + + return client, connectionInfo, nil +} + +// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. +func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.StdioConfig == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + } + + // Prepare STDIO command info for display + cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) + + // Check if environment variables are set + for _, env := range config.StdioConfig.Envs { + if os.Getenv(env) == "" { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + } + } + + // Create STDIO transport + stdioTransport := transport.NewStdio( + config.StdioConfig.Command, + config.StdioConfig.Envs, + config.StdioConfig.Args..., + ) + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + StdioCommandString: &cmdString, + } + + client := client.NewClient(stdioTransport) + + // Return nil for cmd since mark3labs/mcp-go manages the process internally + return client, connectionInfo, nil +} + +// createSSEConnection creates a SSE-based MCP client connection without holding locks. +func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + } + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display + } + + // Create SSE transport + sseTransport, err := transport.NewSSE(*config.ConnectionString) + if err != nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + } + + client := client.NewClient(sseTransport) + + return client, connectionInfo, nil +} + +// createInProcessConnection creates an in-process MCP client connection without holding locks. +// This allows direct connection to an MCP server running in the same process, providing +// the lowest latency and highest performance for tool execution. +func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.InProcessServer == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + } + + // Type assert to ensure we have a proper MCP server + mcpServer, ok := config.InProcessServer.(*server.MCPServer) + if !ok { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcessServer must be a *server.MCPServer instance") + } + + // Create in-process client directly connected to the provided server + inProcessClient, err := client.NewInProcessClient(mcpServer) + if err != nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + } + + return inProcessClient, connectionInfo, nil +} + +// cleanup performs cleanup of all MCP resources including clients and local server. +// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and +// cleans up the local MCP server. It handles proper cancellation of SSE contexts +// and closes all transport connections. +// +// Returns: +// - error: Always returns nil, but maintains error interface for consistency +func (m *MCPManager) cleanup() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Disconnect all external MCP clients + for name := range m.clientMap { + if err := m.removeClientUnsafe(name); err != nil { + m.logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, name, err) + } + } + + // Clear the client map + m.clientMap = make(map[string]*MCPClient) + + // Clear local server reference + // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically + if m.server != nil { + m.logger.Info(MCPLogPrefix + " Clearing local MCP server reference") + m.server = nil + m.serverRunning = false + } + + m.logger.Info(MCPLogPrefix + " MCP cleanup completed") + return nil +} diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 881c0aced..c4aa299b8 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -3,12 +3,17 @@ package providers import ( + "bufio" + "bytes" + "context" "fmt" + "io" + "net/http" + "strings" "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -57,6 +62,80 @@ type AnthropicChatResponse struct { } `json:"usage"` // Token usage statistics } +// AnthropicStreamEvent represents a single event in the Anthropic streaming response. +// It corresponds to the various event types defined in Anthropic's Messages API streaming documentation. +type AnthropicStreamEvent struct { + Type string `json:"type"` + Message *AnthropicStreamMessage `json:"message,omitempty"` + Index *int `json:"index,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Delta *AnthropicDelta `json:"delta,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` + Error *AnthropicStreamError `json:"error,omitempty"` +} + +// AnthropicStreamMessage represents the message structure in streaming events. +// This appears in message_start events and contains the initial message structure. +type AnthropicStreamMessage struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage *schemas.LLMUsage `json:"usage"` +} + +// AnthropicContentBlock represents a content block in Anthropic responses. +// This includes text, tool_use, thinking, and web_search_tool_result blocks. +type AnthropicContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input map[string]interface{} `json:"input,omitempty"` + Thinking string `json:"thinking,omitempty"` + // Web search tool result specific fields + ToolUseID string `json:"tool_use_id,omitempty"` + Content []AnthropicToolContent `json:"content,omitempty"` +} + +// AnthropicToolContent represents content within tool result blocks +type AnthropicToolContent struct { + Type string `json:"type"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` + EncryptedContent string `json:"encrypted_content,omitempty"` + PageAge *string `json:"page_age,omitempty"` +} + +// AnthropicDelta represents incremental updates to content blocks during streaming. +// This includes all delta types: text_delta, input_json_delta, thinking_delta, and signature_delta. +type AnthropicDelta struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + PartialJSON string `json:"partial_json,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// AnthropicUsage represents the usage information for Anthropic's API. +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// AnthropicStreamError represents error events in the streaming response. +type AnthropicStreamError struct { + Type string `json:"type"` + Message string `json:"message"` +} + // AnthropicError represents the error response structure from Anthropic's API. // It includes error type and message information. type AnthropicError struct { @@ -67,10 +146,21 @@ type AnthropicError struct { } `json:"error"` // Error details } +type AnthropicImageContent struct { + Type ImageContentType `json:"type"` + URL string `json:"url"` + MediaType string `json:"media_type,omitempty"` +} + // AnthropicProvider implements the Provider interface for Anthropic's Claude API. type AnthropicProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + apiVersion string // API version for the provider + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse + customProviderConfig *schemas.CustomProviderConfig // Custom provider config } // anthropicChatResponsePool provides a pool for Anthropic chat response objects. @@ -115,69 +205,103 @@ func releaseAnthropicTextResponse(resp *AnthropicTextResponse) { } } +// Since Anthropic always needs to have a max_tokens parameter, we set a default value if not provided. +const ( + AnthropicDefaultMaxTokens = 4096 +) + +// mapAnthropicFinishReasonToOpenAI maps Anthropic finish reasons to OpenAI-compatible ones +func MapAnthropicFinishReason(anthropicReason string) string { + switch anthropicReason { + case "end_turn": + return "stop" + case "max_tokens": + return "length" + case "stop_sequence": + return "stop" + case "tool_use": + return "tool_calls" + default: + // Pass through Anthropic-specific reasons like "pause_turn", "refusal", etc. + return anthropicReason + } +} + // NewAnthropicProvider creates a new Anthropic provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AnthropicProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } // Pre-warm response pools - for range config.ConcurrencyAndBufferSize.Concurrency { + for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { anthropicTextResponsePool.Put(&AnthropicTextResponse{}) anthropicChatResponsePool.Put(&AnthropicChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) } // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.anthropic.com" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + return &AnthropicProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + streamClient: streamClient, + apiVersion: "2023-06-01", + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + customProviderConfig: config.CustomProviderConfig, } } // GetProviderKey returns the provider identifier for Anthropic. func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider { - return schemas.Anthropic + return getProviderName(schemas.Anthropic, provider.customProviderConfig) } // prepareTextCompletionParams prepares text completion parameters for Anthropic's API. // It handles parameter mapping and conversion to the format expected by Anthropic. // Returns the modified parameters map. func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string]interface{}) map[string]interface{} { - // Check if there is a key entry for max_tokens - if maxTokens, exists := params["max_tokens"]; exists { - // Check if max_tokens_to_sample is already present - if _, exists := params["max_tokens_to_sample"]; !exists { - // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + maxTokens, maxTokensExists := params["max_tokens"] + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + if maxTokensExists { params["max_tokens_to_sample"] = maxTokens + } else { + params["max_tokens_to_sample"] = AnthropicDefaultMaxTokens } - delete(params, "max_tokens") } + + delete(params, "max_tokens") + return params } // completeRequest sends a request to Anthropic's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AnthropicProvider) completeRequest(requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { // Marshal the request body - jsonData, err := json.Marshal(requestBody) + jsonData, err := sonic.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey()) } // Create the request with the JSON body @@ -186,26 +310,27 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + req.SetRequestURI(url) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) - req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("anthropic-version", provider.apiVersion) + req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) + var errorResp AnthropicError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -224,7 +349,11 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationTextCompletion); err != nil { + return nil, err + } + preparedParams := provider.prepareTextCompletionParams(prepareParams(params)) // Merge additional parameters @@ -233,7 +362,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/complete", key) + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value) if err != nil { return nil, err } @@ -242,34 +371,45 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param response := acquireAnthropicTextResponse() defer releaseAnthropicTextResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = []schemas.BifrostResponseChoice{ - { - Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Completion, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Completion, + }, + }, + }, }, }, + Usage: &schemas.LLMUsage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + Model: response.Model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: provider.GetProviderKey(), + }, } - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse } - bifrostResponse.Model = response.Model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, - RawResponse: rawResponse, + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil @@ -278,52 +418,226 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // ChatCompletion performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + return nil, err + } + + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + + // Merge additional parameters + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value) + if err != nil { + return nil, err + } + + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{} + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: provider.GetProviderKey(), + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. +func buildAnthropicImageSourceMap(imgContent *schemas.ImageURLStruct) map[string]interface{} { + if imgContent == nil { + return nil + } + + sanitizedURL, _ := SanitizeImageURL(imgContent.URL) + urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) + + formattedImgContent := AnthropicImageContent{ + Type: urlTypeInfo.Type, + } + + if urlTypeInfo.MediaType != nil { + formattedImgContent.MediaType = *urlTypeInfo.MediaType + } + + if urlTypeInfo.DataURLWithoutPrefix != nil { + formattedImgContent.URL = *urlTypeInfo.DataURLWithoutPrefix + } else { + formattedImgContent.URL = sanitizedURL + } + + sourceMap := map[string]interface{}{ + "type": string(formattedImgContent.Type), // "base64" or "url" + } + + if formattedImgContent.Type == ImageContentTypeURL { + sourceMap["url"] = formattedImgContent.URL + } else { + if formattedImgContent.MediaType != "" { + sourceMap["media_type"] = formattedImgContent.MediaType + } + sourceMap["data"] = formattedImgContent.URL // URL field contains base64 data string + } + return sourceMap +} + +func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { + // Add system messages if present + var systemMessages []BedrockAnthropicSystemMessage + for _, msg := range messages { + if msg.Role == schemas.ModelChatMessageRoleSystem { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *block.Text, + }) + } + } + } + } + } + // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} + var content []interface{} - imageContent := map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": msg.ImageContent.Type, - }, - } + if msg.Role != schemas.ModelChatMessageRoleSystem { + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": *msg.ToolMessage.ToolCallID, + } + + var toolCallResultContent []map[string]interface{} + + if msg.Content.ContentStr != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "text", + "text": *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + } + } - // Handle different image source types - if *msg.ImageContent.Type == "url" { - imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL + toolCallResult["content"] = toolCallResultContent + content = append(content, toolCallResult) } else { - imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType - imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL - } + // Add text content if present + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + content = append(content, map[string]interface{}{ + "type": "text", + "text": *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + content = append(content, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + if block.ImageURL != nil { + imageSource := buildAnthropicImageSourceMap(block.ImageURL) + if imageSource != nil { + content = append(content, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } + } + } + } - content = append(content, imageContent) + // Add thinking content if present in AssistantMessage + if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil { + content = append(content, map[string]interface{}{ + "type": "thinking", + "thinking": *msg.AssistantMessage.Thought, + }) + } - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, - }) + // Add tool calls as content if present + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.Function.Name != nil { + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + // If unmarshaling fails, use a simple string representation + input = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } + + toolUseContent := map[string]interface{}{ + "type": "tool_use", + "name": *toolCall.Function.Name, + "input": input, + } + + if toolCall.ID != nil { + toolUseContent["id"] = *toolCall.ID + } + + content = append(content, toolUseContent) + } + } + } } - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) + if len(content) > 0 { + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } } } preparedParams := prepareParams(params) + // If max_tokens is not provided, set a default value + if _, exists := preparedParams["max_tokens"]; !exists { + preparedParams["max_tokens"] = AnthropicDefaultMaxTokens + } + // Transform tools if present if params != nil && params.Tools != nil && len(*params.Tools) > 0 { var tools []map[string]interface{} @@ -338,49 +652,124 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] preparedParams["tools"] = tools } - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + // Transform tool choice if present + if params != nil && params.ToolChoice != nil { + if params.ToolChoice.ToolChoiceStr != nil { + preparedParams["tool_choice"] = map[string]interface{}{ + "type": *params.ToolChoice.ToolChoiceStr, + } + } else if params.ToolChoice.ToolChoiceStruct != nil { + switch toolChoice := params.ToolChoice.ToolChoiceStruct.Type; toolChoice { + case schemas.ToolChoiceTypeFunction: + fallthrough + case "tool": + preparedParams["tool_choice"] = map[string]interface{}{ + "type": "tool", + "name": params.ToolChoice.ToolChoiceStruct.Function.Name, + } + default: + preparedParams["tool_choice"] = map[string]interface{}{ + "type": toolChoice, + } + } + } + } - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) - if err != nil { - return nil, err + if len(systemMessages) > 0 { + var messages []string + for _, message := range systemMessages { + messages = append(messages, message.Text) + } + + preparedParams["system"] = strings.Join(messages, " ") } - // Create response object from pool - response := acquireAnthropicChatResponse() - defer releaseAnthropicChatResponse(response) + // Post-process formattedMessages for tool call results + processedFormattedMessages := []map[string]interface{}{} // Use a new slice + i := 0 + for i < len(formattedMessages) { + currentMsg := formattedMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk || currentRole == "" { + // If role is of an unexpected type, missing, or empty, treat as non-tool message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. + var accumulatedToolResults []interface{} - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) - if bifrostErr != nil { - return nil, bifrostErr + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(formattedMessages) { + nextMsg := formattedMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole == "" || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing/empty + } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) + } + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedFormattedMessages = append(processedFormattedMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + } } + formattedMessages = processedFormattedMessages // Update with processed messages - // Process the response into our BifrostResponse format - var choices []schemas.BifrostResponseChoice + return formattedMessages, preparedParams +} - // Process content and tool calls - for i, c := range response.Content { - var content string - var toolCalls []schemas.ToolCall +func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Collect all content and tool calls into a single message + var toolCalls []schemas.ToolCall + var thinking string + var contentBlocks []schemas.ContentBlock + // Process content and tool calls + for _, c := range response.Content { switch c.Type { case "thinking": - content = c.Thinking + thinking = c.Thinking case "text": - content = c.Text + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: "text", + Text: &c.Text, + }) case "tool_use": function := schemas.FunctionCall{ Name: &c.Name, } - args, err := json.Marshal(c.Input) + args, err := sonic.Marshal(c.Input) if err != nil { function.Arguments = fmt.Sprintf("%v", c.Input) } else { @@ -388,36 +777,501 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] } toolCalls = append(toolCalls, schemas.ToolCall{ - Type: StrPtr("function"), + Type: Ptr("function"), ID: &c.ID, Function: function, }) } + } - choices = append(choices, schemas.BifrostResponseChoice{ - Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &content, - ToolCalls: &toolCalls, - }, - FinishReason: &response.StopReason, - StopString: response.StopSequence, - }) + // Create the assistant message + var assistantMessage *schemas.AssistantMessage + + // Create AssistantMessage if we have tool calls or thinking + if len(toolCalls) > 0 || thinking != "" { + assistantMessage = &schemas.AssistantMessage{} + if len(toolCalls) > 0 { + assistantMessage.ToolCalls = &toolCalls + } + if thinking != "" { + assistantMessage.Thought = &thinking + } } + // Create a single choice with the collected content bifrostResponse.ID = response.ID - bifrostResponse.Choices = choices - bifrostResponse.Usage = schemas.LLMUsage{ + bifrostResponse.Choices = []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + AssistantMessage: assistantMessage, + }, + StopString: response.StopSequence, + }, + FinishReason: func() *string { + if response.StopReason != "" { + mapped := MapAnthropicFinishReason(response.StopReason) + return &mapped + } + return nil + }(), + }, + } + bifrostResponse.Usage = &schemas.LLMUsage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, } bifrostResponse.Model = response.Model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, - RawResponse: rawResponse, - } return bifrostResponse, nil } + +// Embedding is not supported by the Anthropic provider. +func (provider *AnthropicProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "anthropic") +} + +// ChatCompletionStream performs a streaming chat completion request to the Anthropic API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + return nil, err + } + + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + + // Merge additional parameters and set stream to true + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Anthropic headers + headers := map[string]string{ + "Content-Type": "application/json", + "x-api-key": key.Value, + "anthropic-version": provider.apiVersion, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared Anthropic streaming logic + return handleAnthropicStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/messages", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + params, + postHookRunner, + provider.logger, + ) +} + +// handleAnthropicStreaming handles streaming for Anthropic-compatible APIs (Anthropic, Vertex Claude models). +// This shared function reduces code duplication between providers that use the same SSE event format. +func handleAnthropicStreaming( + ctx context.Context, + httpClient *http.Client, + url string, + requestBody map[string]interface{}, + headers map[string]string, + extraHeaders map[string]string, + providerType schemas.ModelProvider, + params *schemas.ModelParameters, + postHookRunner schemas.PostHookRunner, + logger schemas.Logger, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerType) + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType) + } + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, extraHeaders, nil) + + // Make the request + resp, err := httpClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerType, nil, nil) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + chunkIndex := -1 + + // Track minimal state needed for response format + var messageID string + var modelName string + var usage *schemas.LLMUsage + var finishReason *string + + // Track SSE event parsing state + var eventType string + var eventData string + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE event - track event type and data separately + if strings.HasPrefix(line, "event: ") { + eventType = strings.TrimPrefix(line, "event: ") + continue + } else if strings.HasPrefix(line, "data: ") { + eventData = strings.TrimPrefix(line, "data: ") + } else { + continue + } + + // Skip if we don't have both event type and data + if eventType == "" || eventData == "" { + continue + } + + var event AnthropicStreamEvent + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { + logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) + continue + } + + if event.Usage != nil { + usage = &schemas.LLMUsage{ + PromptTokens: event.Usage.InputTokens, + CompletionTokens: event.Usage.OutputTokens, + TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens, + } + } + if event.Delta != nil && event.Delta.StopReason != nil { + mappedReason := MapAnthropicFinishReason(*event.Delta.StopReason) + finishReason = &mappedReason + } + + // Handle different event types + switch eventType { + case "message_start": + if event.Message != nil { + messageID = event.Message.ID + modelName = event.Message.Model + + // Send first chunk with role + if event.Message.Role != "" { + chunkIndex++ + role := event.Message.Role + + // Create streaming response for message start with role + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Role: &role, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + } + + case "content_block_start": + if event.Index != nil && event.ContentBlock != nil { + chunkIndex++ + + // Handle different content block types + switch event.ContentBlock.Type { + case "tool_use": + // Tool use content block initialization + if event.ContentBlock.Name != "" && event.ContentBlock.ID != "" { + // Create streaming response for tool start + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: *event.Index, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + ToolCalls: []schemas.ToolCall{ + { + Type: func() *string { s := "function"; return &s }(), + ID: &event.ContentBlock.ID, + Function: schemas.FunctionCall{ + Name: &event.ContentBlock.Name, + }, + }, + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + default: + thought := "" + if event.ContentBlock.Thinking != "" { + thought = event.ContentBlock.Thinking + } + content := "" + if event.ContentBlock.Text != "" { + content = event.ContentBlock.Text + } + + // Send empty message for other content block types + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: *event.Index, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Thought: &thought, + Content: &content, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + } + + case "content_block_delta": + if event.Index != nil && event.Delta != nil { + chunkIndex++ + + // Handle different delta types + switch event.Delta.Type { + case "text_delta": + if event.Delta.Text != "" { + // Create streaming response for this delta + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: *event.Index, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Content: &event.Delta.Text, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + + case "input_json_delta": + // Handle tool use streaming - accumulate partial JSON + if event.Delta.PartialJSON != "" { + // Create streaming response for tool input delta + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: *event.Index, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + ToolCalls: []schemas.ToolCall{ + { + Type: func() *string { s := "function"; return &s }(), + Function: schemas.FunctionCall{ + Arguments: event.Delta.PartialJSON, + }, + }, + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + + case "thinking_delta": + // Handle thinking content streaming + if event.Delta.Thinking != "" { + // Create streaming response for thinking delta + streamResponse := &schemas.BifrostResponse{ + ID: messageID, + Object: "chat.completion.chunk", + Model: modelName, + Choices: []schemas.BifrostResponseChoice{ + { + Index: *event.Index, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Thought: &event.Delta.Thinking, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerType, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, logger) + } + + case "signature_delta": + // Handle signature verification for thinking content + // This is used to verify the integrity of thinking content + + } + } + + case "content_block_stop": + // Content block is complete, no specific action needed for streaming + continue + + case "message_delta": + continue + + case "message_stop": + continue + + case "ping": + // Ping events are just keepalive, no action needed + continue + + case "error": + if event.Error != nil { + // Send error through channel before closing + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: &event.Error.Type, + Message: event.Error.Message, + }, + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + } + return + + default: + // Unknown event type - handle gracefully as per Anthropic's versioning policy + // New event types may be added, so we should not error but log and continue + logger.Debug(fmt.Sprintf("Unknown %s stream event type: %s, data: %s", providerType, eventType, eventData)) + continue + } + + // Reset for next event + eventType = "" + eventData = "" + } + + if err := scanner.Err(); err != nil { + logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerType, err)) + processAndSendError(ctx, postHookRunner, err, responseChan, logger) + } else { + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, params, providerType) + handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger) + } + }() + + return responseChan, nil +} + +func (provider *AnthropicProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "anthropic") +} + +func (provider *AnthropicProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "anthropic") +} + +func (provider *AnthropicProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "anthropic") +} + +func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "anthropic") +} diff --git a/core/providers/azure.go b/core/providers/azure.go index 13e8e2ee1..c542732c0 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -3,12 +3,13 @@ package providers import ( + "context" "fmt" + "net/http" "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -30,18 +31,6 @@ type AzureTextResponse struct { Usage schemas.LLMUsage `json:"usage"` // Token usage statistics } -// AzureChatResponse represents the response structure from Azure's chat completion API. -// It includes completion choices, model information, and usage statistics. -type AzureChatResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Object string `json:"object"` // Type of completion (always "chat.completion") - Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices - Model string `json:"model"` // Model used for the completion - Created int `json:"created"` // Unix timestamp of completion creation - SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request - Usage schemas.LLMUsage `json:"usage"` // Token usage statistics -} - // AzureError represents the error response structure from Azure's API. // It includes error code and message information. type AzureError struct { @@ -51,6 +40,9 @@ type AzureError struct { } `json:"error"` } +// AzureAuthorizationTokenKey is the context key for the Azure authentication token. +const AzureAuthorizationTokenKey ContextKey = "azure-authorization-token" + // azureTextCompletionResponsePool provides a pool for Azure text completion response objects. var azureTextCompletionResponsePool = sync.Pool{ New: func() interface{} { @@ -58,26 +50,26 @@ var azureTextCompletionResponsePool = sync.Pool{ }, } -// azureChatResponsePool provides a pool for Azure chat response objects. -var azureChatResponsePool = sync.Pool{ - New: func() interface{} { - return &AzureChatResponse{} - }, -} - -// acquireAzureChatResponse gets an Azure chat response from the pool and resets it. -func acquireAzureChatResponse() *AzureChatResponse { - resp := azureChatResponsePool.Get().(*AzureChatResponse) - *resp = AzureChatResponse{} // Reset the struct - return resp -} - -// releaseAzureChatResponse returns an Azure chat response to the pool. -func releaseAzureChatResponse(resp *AzureChatResponse) { - if resp != nil { - azureChatResponsePool.Put(resp) - } -} +// // azureChatResponsePool provides a pool for Azure chat response objects. +// var azureChatResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireAzureChatResponse gets an Azure chat response from the pool and resets it. +// func acquireAzureChatResponse() *schemas.BifrostResponse { +// resp := azureChatResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseAzureChatResponse returns an Azure chat response to the pool. +// func releaseAzureChatResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// azureChatResponsePool.Put(resp) +// } +// } // acquireAzureTextResponse gets an Azure text completion response from the pool and resets it. func acquireAzureTextResponse() *AzureTextResponse { @@ -95,38 +87,47 @@ func releaseAzureTextResponse(resp *AzureTextResponse) { // AzureProvider implements the Provider interface for Azure's OpenAI API. type AzureProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests - meta schemas.MetaConfig // Azure-specific configuration + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse } // NewAzureProvider creates a new Azure provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. -func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AzureProvider { - setConfigDefaults(config) +func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) { + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { - azureChatResponsePool.Put(&AzureChatResponse{}) + // azureChatResponsePool.Put(&schemas.BifrostResponse{}) azureTextCompletionResponsePool.Put(&AzureTextResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + } // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) return &AzureProvider{ - logger: logger, - client: client, - meta: config.MetaConfig, - } + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil } // GetProviderKey returns the provider identifier for Azure. @@ -137,54 +138,37 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Azure's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AzureProvider) completeRequest(requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { +func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) { + if key.AzureKeyConfig == nil { + return nil, newConfigurationError("azure key config not set", schemas.Azure) + } + // Marshal the request body - jsonData, err := json.Marshal(requestBody) + jsonData, err := sonic.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Azure) } - if provider.meta.GetEndpoint() == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "endpoint not set", - }, - } + if key.AzureKeyConfig.Endpoint == "" { + return nil, newConfigurationError("endpoint not set", schemas.Azure) } - url := *provider.meta.GetEndpoint() + url := key.AzureKeyConfig.Endpoint - if provider.meta.GetDeployments() != nil { - deployment := provider.meta.GetDeployments()[model] + if key.AzureKeyConfig.Deployments != nil { + deployment := key.AzureKeyConfig.Deployments[model] if deployment == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("deployment if not found for model %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure) } - apiVersion := provider.meta.GetAPIVersion() + apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = StrPtr("2024-02-01") + apiVersion = Ptr("2024-02-01") } url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) } else { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "deployments not set", - }, - } + return nil, newConfigurationError("deployments not set", schemas.Azure) } // Create the request with the JSON body @@ -193,25 +177,32 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + req.SetRequestURI(url) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("api-key", key) + if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) + // Ensure api-key is not accidentally present (from extra headers, etc.) + req.Header.Del("api-key") + } else { + req.Header.Set("api-key", key.Value) + } + req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body()))) + var errorResp AzureError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -230,7 +221,7 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := prepareParams(params) // Merge additional parameters @@ -239,7 +230,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s "prompt": text, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "completions", key, model) if err != nil { return nil, err } @@ -248,11 +239,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s response := acquireAzureTextResponse() defer releaseAzureTextResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } @@ -263,26 +250,41 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s if len(response.Choices) > 0 { choices = append(choices, schemas.BifrostResponseChoice{ Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Choices[0].Text, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Choices[0].Text, + }, + }, + LogProbs: &schemas.LogProbs{ + Text: response.Choices[0].LogProbs, + }, }, FinishReason: response.Choices[0].FinishReason, - LogProbs: &schemas.LogProbs{ - Text: response.Choices[0].LogProbs, - }, }) } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = choices - bifrostResponse.Model = response.Model - bifrostResponse.Created = response.Created - bifrostResponse.SystemFingerprint = response.SystemFingerprint - bifrostResponse.Usage = response.Usage - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Azure, - RawResponse: rawResponse, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: &response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Azure, + }, + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil @@ -291,17 +293,8 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - preparedParams := prepareParams(params) - - // Format messages for Azure API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) - } +func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) // Merge additional parameters requestBody := mergeConfig(map[string]interface{}{ @@ -309,34 +302,155 @@ func (provider *AzureProvider) ChatCompletion(model, key string, messages []sche "messages": formattedMessages, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "chat/completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "chat/completions", key, model) if err != nil { return nil, err } // Create response object from pool - response := acquireAzureChatResponse() - defer releaseAzureChatResponse(response) + // response := acquireAzureChatResponse() + // defer releaseAzureChatResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) + response := &schemas.BifrostResponse{} - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = response.Choices - bifrostResponse.Model = response.Model - bifrostResponse.Created = response.Created - bifrostResponse.SystemFingerprint = response.SystemFingerprint - bifrostResponse.Usage = response.Usage - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Azure, - RawResponse: rawResponse, + response.ExtraFields.Provider = schemas.Azure + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params } - return bifrostResponse, nil + return response, nil +} + +// Embedding generates embeddings for the given input text(s) using Azure OpenAI. +// The input can be either a single string or a slice of strings for batch embedding. +// Returns a BifrostResponse containing the embedding(s) and any error that occurred. +func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body + requestBody := prepareOpenAIEmbeddingRequest(input, params) + + responseBody, err := provider.completeRequest(ctx, requestBody, "embeddings", key, model) + if err != nil { + return nil, err + } + + // Pre-allocate response structs from pools + // response := acquireAzureChatResponse() + // defer releaseAzureChatResponse(response) + + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.Azure + + if params != nil { + response.ExtraFields.Params = *params + } + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + +// ChatCompletionStream performs a streaming chat completion request to Azure's OpenAI API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + if key.AzureKeyConfig == nil { + return nil, newConfigurationError("azure key config not set", schemas.Azure) + } + + // Merge additional parameters and set stream to true + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Construct Azure-specific URL with deployment + if key.AzureKeyConfig.Endpoint == "" { + return nil, newConfigurationError("endpoint not set", schemas.Azure) + } + + baseURL := key.AzureKeyConfig.Endpoint + var fullURL string + + if key.AzureKeyConfig.Deployments != nil { + deployment := key.AzureKeyConfig.Deployments[model] + if deployment == "" { + return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure) + } + + apiVersion := key.AzureKeyConfig.APIVersion + if apiVersion == nil { + apiVersion = Ptr("2024-02-01") + } + + fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion) + } else { + return nil, newConfigurationError("deployments not set", schemas.Azure) + } + + // Prepare Azure-specific headers + headers := make(map[string]string) + headers["Content-Type"] = "application/json" + headers["Accept"] = "text/event-stream" + headers["Cache-Control"] = "no-cache" + + // Set Azure authentication - either Bearer token or api-key + if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { + headers["Authorization"] = fmt.Sprintf("Bearer %s", authToken) + } else { + headers["api-key"] = key.Value + } + + // Use shared streaming logic from OpenAI + return handleOpenAIStreaming( + ctx, + provider.streamClient, + fullURL, + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Azure, // Provider type + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *AzureProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "azure") +} + +func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "azure") +} + +func (provider *AzureProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "azure") +} + +func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "azure") } diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 3a9b35c39..f24ba8c35 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -7,19 +7,23 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" + "errors" "fmt" "io" + "maps" "net/http" "net/url" "strings" "sync" "time" - "github.com/goccy/go-json" + "bufio" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -49,7 +53,9 @@ type BedrockChatResponse struct { Output struct { Message struct { Content []struct { - Text string `json:"text"` // Message content + Text *string `json:"text"` // Message content + // Bedrock returns a union type where either Text or ToolUse is present (mutually exclusive) + BedrockAnthropicToolUseMessage } `json:"content"` // Array of message content Role string `json:"role"` // Role of the message sender } `json:"message"` // Message structure @@ -82,7 +88,7 @@ type BedrockMistralContent struct { type BedrockMistralChatMessage struct { Role schemas.ModelChatMessageRole `json:"role"` // Role of the message sender Content []BedrockMistralContent `json:"content"` // Array of message content - ToolCalls *[]BedrockMistralToolCall `json:"tool_calls,omitempty"` // Optional tool calls + ToolCalls *[]BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls ToolCallID *string `json:"tool_call_id,omitempty"` // Optional tool call ID } @@ -94,8 +100,8 @@ type BedrockAnthropicImageMessage struct { // BedrockAnthropicImage represents image data for Anthropic models. type BedrockAnthropicImage struct { - Format string `json:"string"` // Image format - Source BedrockAnthropicImageSource `json:"source"` // Image source + Format string `json:"format,omitempty"` // Image format + Source BedrockAnthropicImageSource `json:"source,omitempty"` // Image source } // BedrockAnthropicImageSource represents the source of an image for Anthropic models. @@ -103,10 +109,28 @@ type BedrockAnthropicImageSource struct { Bytes string `json:"bytes"` // Base64 encoded image data } -// BedrockMistralToolCall represents a tool call for Mistral models. -type BedrockMistralToolCall struct { - ID string `json:"id"` // Tool call ID - Function schemas.Function `json:"function"` // Function to call +// BedrockAnthropicToolUseMessage represents a tool use message for Anthropic models. +type BedrockAnthropicToolUseMessage struct { + ToolUse *BedrockAnthropicToolUse `json:"toolUse"` +} + +// BedrockToolChoice represents the tool choice configuration for Bedrock models. +type BedrockToolChoice struct { + Auto map[string]interface{} `json:"auto,omitempty"` + Any map[string]interface{} `json:"any,omitempty"` + Tool *BedrockSpecificTool `json:"tool,omitempty"` +} + +// BedrockSpecificTool represents a specific tool choice configuration. +type BedrockSpecificTool struct { + Type string `json:"type"` // "tool" always + Name string `json:"name"` +} + +type BedrockAnthropicToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` } // BedrockAnthropicToolCall represents a tool call for Anthropic models. @@ -128,11 +152,60 @@ type BedrockError struct { Message string `json:"message"` // Error message } +// BedrockStreamMessageStartEvent is emitted when the assistant message starts. +type BedrockStreamMessageStartEvent struct { + MessageStart struct { + Role string `json:"role"` // e.g. "assistant" + } `json:"messageStart"` +} + +// BedrockStreamContentBlockDeltaEvent is sent for each content delta chunk (text, reasoning, tool use). +type BedrockStreamContentBlockDeltaEvent struct { + ContentBlockDelta struct { + Delta struct { + Text string `json:"text,omitempty"` + ReasoningContent json.RawMessage `json:"reasoningContent,omitempty"` + ToolUse json.RawMessage `json:"toolUse,omitempty"` + } `json:"delta"` + ContentBlockIndex int `json:"contentBlockIndex"` + } `json:"contentBlockDelta"` +} + +// BedrockStreamContentBlockStopEvent indicates the end of a content block. +type BedrockStreamContentBlockStopEvent struct { + ContentBlockStop struct { + ContentBlockIndex int `json:"contentBlockIndex"` + } `json:"contentBlockStop"` +} + +// BedrockStreamMessageStopEvent marks the end of the assistant message. +type BedrockStreamMessageStopEvent struct { + MessageStop struct { + StopReason string `json:"stopReason"` // e.g. "stop", "max_tokens", "tool_use" + } `json:"messageStop"` +} + +// BedrockStreamMetadataEvent contains metadata after streaming ends. +type BedrockStreamMetadataEvent struct { + Metadata struct { + Usage struct { + InputTokens int `json:"inputTokens"` + OutputTokens int `json:"outputTokens"` + TotalTokens int `json:"totalTokens"` + } `json:"usage"` + Metrics struct { + LatencyMs float64 `json:"latencyMs"` + } `json:"metrics"` + } `json:"metadata"` +} + // BedrockProvider implements the Provider interface for AWS Bedrock. type BedrockProvider struct { - logger schemas.Logger // Logger for provider operations - client *http.Client // HTTP client for API requests - meta schemas.MetaConfig // AWS-specific configuration + logger schemas.Logger // Logger for provider operations + client *http.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + customProviderConfig *schemas.CustomProviderConfig // Custom provider config + sendBackRawResponse bool // Whether to include raw response in BifrostResponse } // bedrockChatResponsePool provides a pool for Bedrock response objects. @@ -159,49 +232,53 @@ func releaseBedrockChatResponse(resp *BedrockChatResponse) { // NewBedrockProvider creates a new Bedrock provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and AWS-specific settings. -func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) *BedrockProvider { - setConfigDefaults(config) +func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*BedrockProvider, error) { + config.CheckAndSetDefaults() client := &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)} // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { bedrockChatResponsePool.Put(&BedrockChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) } return &BedrockProvider{ - logger: logger, - client: client, - meta: config.MetaConfig, - } + logger: logger, + client: client, + networkConfig: config.NetworkConfig, + customProviderConfig: config.CustomProviderConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil } // GetProviderKey returns the provider identifier for Bedrock. func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { - return schemas.Bedrock + return getProviderName(schemas.Bedrock, provider.customProviderConfig) } // CompleteRequest sends a request to Bedrock's API and handles the response. // It constructs the API URL, sets up AWS authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *BedrockProvider) completeRequest(requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) { - if provider.meta == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "meta config for bedrock is not provided", - }, - } - } +func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key) ([]byte, *schemas.BifrostError) { + config := key.BedrockKeyConfig region := "us-east-1" - if provider.meta.GetRegion() != nil { - region = *provider.meta.GetRegion() + if config.Region != nil { + region = *config.Region } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: Ptr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -212,7 +289,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac } // Create the request with the JSON body - req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -223,8 +300,17 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac } } - if err := signAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil { - return nil, err + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // If Value is set, use API Key authentication - else use IAM role authentication + if key.Value != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value)) + } else { + // Sign the request using either explicit credentials or IAM role authentication + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, region, "bedrock", provider.GetProviderKey()); err != nil { + return nil, err + } } // Execute the request @@ -255,9 +341,10 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac if resp.StatusCode != http.StatusOK { var errorResp BedrockError - if err := json.Unmarshal(body, &errorResp); err != nil { + if err := sonic.Unmarshal(body, &errorResp); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &resp.StatusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, @@ -279,15 +366,11 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac // GetTextCompletionResult processes the text completion response from Bedrock. // It handles different model types (Anthropic and Mistral) and formats the response. // Returns a BifrostResponse containing the completion results or an error if processing fails. -func (provider *BedrockProvider) getTextCompletionResult(result []byte, model string) (*schemas.BifrostResponse, *schemas.BifrostError) { - switch model { - case "anthropic.claude-instant-v1:2": - fallthrough - case "anthropic.claude-v2": - fallthrough - case "anthropic.claude-v2:1": +func (provider *BedrockProvider) getTextCompletionResult(result []byte, model string, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) { + switch { + case strings.Contains(model, "anthropic."): var response BedrockAnthropicTextResponse - if err := json.Unmarshal(result, &response); err != nil { + if err := sonic.Unmarshal(result, &response); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -301,31 +384,27 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st Choices: []schemas.BifrostResponseChoice{ { Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Completion, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Completion, + }, + }, + StopString: &response.Stop, }, FinishReason: &response.StopReason, - StopString: &response.Stop, }, }, Model: model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Bedrock, + Provider: providerName, }, }, nil - case "mistral.mixtral-8x7b-instruct-v0:1": - fallthrough - case "mistral.mistral-7b-instruct-v0:2": - fallthrough - case "mistral.mistral-large-2402-v1:0": - fallthrough - case "mistral.mistral-large-2407-v1:0": - fallthrough - case "mistral.mistral-small-2402-v1:0": + case strings.Contains(model, "mistral."): var response BedrockMistralTextResponse - if err := json.Unmarshal(result, &response); err != nil { + if err := sonic.Unmarshal(result, &response); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -339,9 +418,13 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st for i, output := range response.Outputs { choices = append(choices, schemas.BifrostResponseChoice{ Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &output.Text, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &output.Text, + }, + }, }, FinishReason: &output.StopReason, }) @@ -351,123 +434,326 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st Choices: choices, Model: model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Bedrock, + Provider: providerName, }, }, nil } - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("invalid model choice: %s", model), - }, + return nil, newConfigurationError(fmt.Sprintf("invalid model choice: %s", model), providerName) +} + +// parseBedrockAnthropicMessageToolCallContent parses the content of a tool call message. +// It handles both text and JSON content. +// Returns a map containing the parsed content. +func parseBedrockAnthropicMessageToolCallContent(content string) map[string]interface{} { + toolResultContentBlock := map[string]interface{}{} + var parsedJSON interface{} + err := sonic.Unmarshal([]byte(content), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.([]interface{}); ok { + toolResultContentBlock["json"] = map[string]interface{}{"content": arr} + } else { + toolResultContentBlock["json"] = map[string]interface{}{"output": parsedJSON} + } + } else { + toolResultContentBlock["text"] = content } + + return toolResultContentBlock } // PrepareChatCompletionMessages formats chat messages for Bedrock's API. // It handles different model types (Anthropic and Mistral) and formats messages accordingly. // Returns a map containing the formatted messages and any system messages, or an error if formatting fails. -func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schemas.Message, model string) (map[string]interface{}, *schemas.BifrostError) { - switch model { - case "anthropic.claude-instant-v1:2": - fallthrough - case "anthropic.claude-v2": - fallthrough - case "anthropic.claude-v2:1": - fallthrough - case "anthropic.claude-3-sonnet-20240229-v1:0": - fallthrough - case "anthropic.claude-3-5-sonnet-20240620-v1:0": - fallthrough - case "anthropic.claude-3-5-sonnet-20241022-v2:0": - fallthrough - case "anthropic.claude-3-5-haiku-20241022-v1:0": - fallthrough - case "anthropic.claude-3-opus-20240229-v1:0": - fallthrough - case "anthropic.claude-3-7-sonnet-20250219-v1:0": +func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schemas.BifrostMessage, model string) (map[string]interface{}, *schemas.BifrostError) { + switch { + case strings.Contains(model, "anthropic."): // Add system messages if present var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { - if msg.Role == schemas.RoleSystem { - //TODO handling image inputs here - systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ - Text: *msg.Content, - }) + if msg.Role == schemas.ModelChatMessageRoleSystem { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *block.Text, + }) + } + } + } } } // Format messages for Bedrock API var bedrockMessages []map[string]interface{} for _, msg := range messages { - if msg.Role != schemas.RoleSystem { - var content any - if msg.Content != nil { - content = BedrockAnthropicTextMessage{ - Type: "text", - Text: *msg.Content, + var content []interface{} + if msg.Role != schemas.ModelChatMessageRoleSystem { + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "toolUseId": *msg.ToolCallID, } - } else if msg.ImageContent != nil { - content = BedrockAnthropicImageMessage{ - Type: "image", - Image: BedrockAnthropicImage{ - Format: *msg.ImageContent.Type, - Source: BedrockAnthropicImageSource{ - Bytes: msg.ImageContent.URL, - }, - }, + var toolResultContentBlocks []map[string]interface{} + if msg.Content.ContentStr != nil { + toolResultContentBlocks = append(toolResultContentBlocks, parseBedrockAnthropicMessageToolCallContent(*msg.Content.ContentStr)) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + toolResultContentBlocks = append(toolResultContentBlocks, parseBedrockAnthropicMessageToolCallContent(*block.Text)) + } + } + } + toolCallResult["content"] = toolResultContentBlocks + content = append(content, map[string]interface{}{ + "toolResult": toolCallResult, + }) + } else { + // Bedrock wants only toolUse block on content, text blocks are not allowed when tools are called. + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } + + content = append(content, BedrockAnthropicToolUseMessage{ + ToolUse: &BedrockAnthropicToolUse{ + ToolUseID: *toolCall.ID, + Name: *toolCall.Function.Name, + Input: input, + }, + }) + } + } else { + if msg.Content.ContentStr != nil { + content = append(content, BedrockAnthropicTextMessage{ + Type: "text", + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + content = append(content, BedrockAnthropicTextMessage{ + Type: "text", + Text: *block.Text, + }) + } + if block.ImageURL != nil { + sanitizedURL, _ := SanitizeImageURL(block.ImageURL.URL) + urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) + + formattedImgContent := AnthropicImageContent{ + Type: urlTypeInfo.Type, + } + + if urlTypeInfo.MediaType != nil { + formattedImgContent.MediaType = *urlTypeInfo.MediaType + } + + if urlTypeInfo.DataURLWithoutPrefix != nil { + formattedImgContent.URL = *urlTypeInfo.DataURLWithoutPrefix + } else { + formattedImgContent.URL = sanitizedURL + } + + content = append(content, BedrockAnthropicImageMessage{ + Type: "image", + Image: BedrockAnthropicImage{ + Format: func() string { + if formattedImgContent.MediaType != "" { + mediaType := formattedImgContent.MediaType + // Remove "image/" prefix if present, since normalizeMediaType ensures full format + mediaType = strings.TrimPrefix(mediaType, "image/") + return mediaType + } + return "" + }(), + Source: BedrockAnthropicImageSource{ + Bytes: formattedImgContent.URL, + }, + }, + }) + } + } + } } + } - bedrockMessages = append(bedrockMessages, map[string]interface{}{ - "role": msg.Role, - "content": []interface{}{content}, - }) + if len(content) > 0 { + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } + } + } + + // Post-process bedrockMessages for tool call results + processedBedrockMessages := []map[string]interface{}{} + i := 0 + for i < len(bedrockMessages) { + currentMsg := bedrockMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk { + // If role is of an unexpected type or missing, treat as non-tool message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. + var accumulatedToolResults []interface{} + + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(bedrockMessages) { + nextMsg := bedrockMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing + } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) + } + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedBedrockMessages = append(processedBedrockMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ } } + bedrockMessages = processedBedrockMessages // Update with processed messages body := map[string]interface{}{ "messages": bedrockMessages, } if len(systemMessages) > 0 { - var messages []string - for _, message := range systemMessages { - messages = append(messages, message.Text) - } - - body["system"] = strings.Join(messages, " ") + body["system"] = systemMessages } - return body, nil - case "mistral.mistral-large-2402-v1:0": - fallthrough - case "mistral.mistral-large-2407-v1:0": + case strings.Contains(model, "mistral."): var bedrockMessages []BedrockMistralChatMessage for _, msg := range messages { - var filteredToolCalls []BedrockMistralToolCall - if msg.ToolCalls != nil { - for _, toolCall := range *msg.ToolCalls { - filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{ - ID: *toolCall.ID, - Function: toolCall.Function, - }) + // Check if this is a tool message before changing the role + isToolMessage := msg.Role == schemas.ModelChatMessageRoleTool + + // Convert tool messages to user messages (Mistral doesn't support tool role) + role := msg.Role + switch role { + case schemas.ModelChatMessageRoleTool, schemas.ModelChatMessageRoleSystem: + role = schemas.ModelChatMessageRoleUser + } + + // Only process user and assistant messages + if role != schemas.ModelChatMessageRoleUser && role != schemas.ModelChatMessageRoleAssistant { + continue + } + + var filteredToolCalls []BedrockAnthropicToolCall + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + // Parse the arguments to get parameters + var params interface{} + if toolCall.Function.Arguments != "" { + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms); err != nil { + // If parsing fails, use empty object + params = map[string]interface{}{} + } + } + + filteredToolCalls = append(filteredToolCalls, BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: *toolCall.Function.Name, + Description: "Tool function", // Default description since FunctionCall doesn't have one + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: params, + }, + }, + }) + } } } message := BedrockMistralChatMessage{ - Role: msg.Role, - Content: []BedrockMistralContent{ - {Text: *msg.Content}, - }, + Role: role, } - if len(filteredToolCalls) > 0 { - message.ToolCalls = &filteredToolCalls + // Ensure message has valid content + var hasValidContent bool + switch { + case msg.Content.ContentStr != nil && *msg.Content.ContentStr != "": + message.Content = []BedrockMistralContent{{Text: *msg.Content.ContentStr}} + hasValidContent = true + case msg.Content.ContentBlocks != nil && len(*msg.Content.ContentBlocks) > 0: + for _, b := range *msg.Content.ContentBlocks { + if b.Text != nil && *b.Text != "" { + message.Content = append(message.Content, BedrockMistralContent{Text: *b.Text}) + hasValidContent = true + } + } + } + + // For tool messages that were converted to user messages, ensure they have content + if isToolMessage && !hasValidContent { + // If tool message has no content, create a default content + defaultText := "Tool result received" + if msg.ToolCallID != nil { + defaultText = fmt.Sprintf("Tool result for call ID: %s", *msg.ToolCallID) + } + message.Content = []BedrockMistralContent{{Text: defaultText}} + hasValidContent = true + } + + // Final safety check: ensure message always has content + if !hasValidContent { + message.Content = []BedrockMistralContent{{Text: "Message content"}} + hasValidContent = true } - bedrockMessages = append(bedrockMessages, message) + // Only add messages that have valid content or tool calls + if hasValidContent || len(filteredToolCalls) > 0 { + if len(filteredToolCalls) > 0 { + message.ToolCalls = &filteredToolCalls + } + bedrockMessages = append(bedrockMessages, message) + } } body := map[string]interface{}{ @@ -477,38 +763,17 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema return body, nil } - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("invalid model choice: %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("invalid model choice: %s", model), provider.GetProviderKey()) } // GetChatCompletionTools prepares tool specifications for Bedrock's API. // It formats tool definitions for different model types (Anthropic and Mistral). -// Returns an array of tool specifications for the given model. -func (provider *BedrockProvider) getChatCompletionTools(params *schemas.ModelParameters, model string) []BedrockAnthropicToolCall { - var tools []BedrockAnthropicToolCall - - switch model { - case "anthropic.claude-instant-v1:2": - fallthrough - case "anthropic.claude-v2": - fallthrough - case "anthropic.claude-v2:1": - fallthrough - case "anthropic.claude-3-sonnet-20240229-v1:0": - fallthrough - case "anthropic.claude-3-5-sonnet-20240620-v1:0": - fallthrough - case "anthropic.claude-3-5-sonnet-20241022-v2:0": - fallthrough - case "anthropic.claude-3-5-haiku-20241022-v1:0": - fallthrough - case "anthropic.claude-3-opus-20240229-v1:0": - fallthrough - case "anthropic.claude-3-7-sonnet-20250219-v1:0": +// Returns tool specifications appropriate for the given model type. +func (provider *BedrockProvider) getChatCompletionTools(params *schemas.ModelParameters, model string) (interface{}, *schemas.BifrostError) { + switch { + case strings.Contains(model, "anthropic."), strings.Contains(model, "mistral."): + // Both Anthropic and Mistral models on Bedrock use toolConfig.tools with toolSpec structure + var tools []BedrockAnthropicToolCall for _, tool := range *params.Tools { tools = append(tools, BedrockAnthropicToolCall{ ToolSpec: BedrockAnthropicToolSpec{ @@ -522,30 +787,30 @@ func (provider *BedrockProvider) getChatCompletionTools(params *schemas.ModelPar }, }) } - } + return tools, nil - return tools + default: + return nil, newConfigurationError(fmt.Sprintf("unsupported model for tool calling: %s", model), provider.GetProviderKey()) + } } // prepareTextCompletionParams prepares text completion parameters for Bedrock's API. // It handles parameter mapping and conversion for different model types. // Returns the modified parameters map with model-specific adjustments. func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]interface{}, model string) map[string]interface{} { - switch model { - case "anthropic.claude-instant-v1:2": - fallthrough - case "anthropic.claude-v2": - fallthrough - case "anthropic.claude-v2:1": - // Check if there is a key entry for max_tokens - if maxTokens, exists := params["max_tokens"]; exists { - // Check if max_tokens_to_sample is already present - if _, exists := params["max_tokens_to_sample"]; !exists { - // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + switch { + case strings.Contains(model, "anthropic."): + maxTokens, maxTokensExists := params["max_tokens"] + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + if maxTokensExists { params["max_tokens_to_sample"] = maxTokens + } else { + params["max_tokens_to_sample"] = AnthropicDefaultMaxTokens } - delete(params, "max_tokens") } + + delete(params, "max_tokens") } return params } @@ -553,44 +818,159 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationTextCompletion); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if key.BedrockKeyConfig == nil { + return nil, newConfigurationError("bedrock key config is not provided", providerName) + } + preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model) requestBody := mergeConfig(map[string]interface{}{ "prompt": text, }, preparedParams) - body, err := provider.completeRequest(requestBody, fmt.Sprintf("%s/invoke", model), key) + path := provider.getModelPath("invoke", model, key) + body, err := provider.completeRequest(ctx, requestBody, path, key) if err != nil { return nil, err } - result, err := provider.getTextCompletionResult(body, model) + bifrostResponse, err := provider.getTextCompletionResult(body, model, providerName) if err != nil { return nil, err } - // Parse raw response - var rawResponse interface{} - if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing raw response", - Error: err, - }, + // Parse raw response if enabled + if provider.sendBackRawResponse { + var rawResponse interface{} + if err := sonic.Unmarshal(body, &rawResponse); err != nil { + return nil, newBifrostOperationError("error parsing raw response", err, providerName) } + bifrostResponse.ExtraFields.RawResponse = rawResponse } - result.ExtraFields.RawResponse = rawResponse + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } - return result, nil + return bifrostResponse, nil +} + +// extractToolsFromHistory extracts minimal tool definitions from conversation history. +// It analyzes the messages to find tool-related content and returns whether tool content +// was found and a list of unique minimal tool definitions extracted from the conversation. +// This is needed when Bedrock requires toolConfig but no tools are provided in the current request. +func (provider *BedrockProvider) extractToolsFromHistory(messages []schemas.BifrostMessage) (bool, []BedrockAnthropicToolCall) { + hasToolContent := false + var toolsFromHistory []BedrockAnthropicToolCall + seenTools := make(map[string]BedrockAnthropicToolCall) + + for _, msg := range messages { + // Check for tool result messages + if msg.Role == schemas.ModelChatMessageRoleTool { + hasToolContent = true + } + // Check for assistant messages with tool calls + if msg.Role == schemas.ModelChatMessageRoleAssistant && msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + hasToolContent = true + // Extract tool definitions from tool calls for toolConfig + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + if _, exists := seenTools[toolName]; !exists { + // Create a basic tool definition from the tool call + // Note: We can't fully reconstruct the original tool definition, + // but we can provide a minimal one that satisfies Bedrock's requirement + tool := BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: toolName, + Description: fmt.Sprintf("Tool: %s", toolName), + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + } + seenTools[toolName] = tool + toolsFromHistory = append(toolsFromHistory, tool) + } + } + } + } + } + + return hasToolContent, toolsFromHistory +} + +// prepareToolChoice prepares tool choice configuration for different model types. +// Both Anthropic and Mistral models on Bedrock support toolChoice in toolConfig. +func (provider *BedrockProvider) prepareToolChoice(params *schemas.ModelParameters, model string) interface{} { + if params == nil || params.ToolChoice == nil { + return nil + } + + switch { + case strings.Contains(model, "anthropic."), strings.Contains(model, "mistral."): + // Both Anthropic and Mistral models use toolChoice in toolConfig + // AWS Bedrock supports: "auto", "any", "tool" as union types + if params.ToolChoice.ToolChoiceStr != nil { + choice := *params.ToolChoice.ToolChoiceStr + switch choice { + case string(schemas.ToolChoiceTypeAuto), string(schemas.ToolChoiceTypeAny): + return nil + case string(schemas.ToolChoiceTypeFunction), "tool": + if params.ToolChoice.ToolChoiceStruct == nil { + return nil + } + return &BedrockToolChoice{ + Tool: &BedrockSpecificTool{ + Type: "tool", + Name: params.ToolChoice.ToolChoiceStruct.Function.Name, + }, + } + } + // Note: "none" is not supported by AWS Bedrock for these models + } else if params.ToolChoice.ToolChoiceStruct != nil { + if (params.ToolChoice.ToolChoiceStruct.Type == schemas.ToolChoiceTypeFunction || params.ToolChoice.ToolChoiceStruct.Type == "tool") && + params.ToolChoice.ToolChoiceStruct.Function.Name != "" { + + return &BedrockToolChoice{ + Tool: &BedrockSpecificTool{ + Type: "tool", + Name: params.ToolChoice.ToolChoiceStruct.Function.Name, + }, + } + } + } + } + + return nil } // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if key.BedrockKeyConfig == nil { + return nil, newConfigurationError("bedrock key config is not provided", providerName) + } + messageBody, err := provider.prepareChatCompletionMessages(messages, model) if err != nil { return nil, err @@ -598,27 +978,52 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc preparedParams := prepareParams(params) + if strings.Contains(model, "anthropic.") { + if _, exists := preparedParams["max_tokens"]; !exists { + preparedParams["max_tokens"] = AnthropicDefaultMaxTokens + } + } + // Transform tools if present if params != nil && params.Tools != nil && len(*params.Tools) > 0 { - preparedParams["tools"] = provider.getChatCompletionTools(params, model) - } + tools, err := provider.getChatCompletionTools(params, model) + if err != nil { + return nil, err + } + toolConfig := map[string]interface{}{ + "tools": tools, + } - requestBody := mergeConfig(messageBody, preparedParams) + // Add tool choice if specified + if toolChoice := provider.prepareToolChoice(params, model); toolChoice != nil { + toolConfig["toolChoice"] = toolChoice + } - // Format the path with proper model identifier - path := fmt.Sprintf("%s/converse", model) + preparedParams["toolConfig"] = toolConfig - if provider.meta != nil && provider.meta.GetInferenceProfiles() != nil { - if inferenceProfileId, ok := provider.meta.GetInferenceProfiles()[model]; ok { - if provider.meta.GetARN() != nil { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.GetARN(), inferenceProfileId)) - path = fmt.Sprintf("%s/converse", encodedModelIdentifier) + delete(preparedParams, "tools") + delete(preparedParams, "tool_choice") + } else { + // Check if conversation history contains tool use/result blocks + // Bedrock requires toolConfig when such blocks are present + hasToolContent, toolsFromHistory := provider.extractToolsFromHistory(messages) + + // If conversation contains tool content but no tools provided in current request, + // include the extracted tools to satisfy Bedrock's toolConfig requirement + if hasToolContent && len(toolsFromHistory) > 0 { + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": toolsFromHistory, } } } + requestBody := mergeConfig(messageBody, preparedParams) + + // Format the path with proper model identifier + path := provider.getModelPath("converse", model, key) + // Create the signed request - responseBody, err := provider.completeRequest(requestBody, path, key) + responseBody, err := provider.completeRequest(ctx, requestBody, path, key) if err != nil { return nil, err } @@ -627,40 +1032,96 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc response := acquireBedrockChatResponse() defer releaseBedrockChatResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } - var choices []schemas.BifrostResponseChoice - for i, choice := range response.Output.Message.Content { - choices = append(choices, schemas.BifrostResponseChoice{ - Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &choice.Text, + // Collect all content and tool calls into a single message (similar to Anthropic aggregation) + var toolCalls []schemas.ToolCall + + var contentBlocks []schemas.ContentBlock + // Process content and tool calls + for _, choice := range response.Output.Message.Content { + if choice.Text != nil && *choice.Text != "" { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: "text", + Text: choice.Text, + }) + } + + if choice.ToolUse != nil { + input := choice.ToolUse.Input + if input == nil { + input = map[string]any{} + } + arguments, err := sonic.Marshal(input) + if err != nil { + arguments = []byte("{}") + } + + toolCalls = append(toolCalls, schemas.ToolCall{ + Type: Ptr("function"), + ID: &choice.ToolUse.ToolUseID, + Function: schemas.FunctionCall{ + Name: &choice.ToolUse.Name, + Arguments: string(arguments), + }, + }) + } + } + + // Create the assistant message + var assistantMessage *schemas.AssistantMessage + + // Create AssistantMessage if we have tool calls + if len(toolCalls) > 0 { + assistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Create a single choice with the aggregated content + choices := []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + AssistantMessage: assistantMessage, + }, }, FinishReason: &response.StopReason, - }) + }, } latency := float64(response.Metrics.Latency) - bifrostResponse.Choices = choices - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + Choices: choices, + Usage: &schemas.LLMUsage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Latency: &latency, + Provider: providerName, + }, } - bifrostResponse.Model = model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Latency: &latency, - Provider: schemas.Bedrock, - RawResponse: rawResponse, + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil @@ -671,7 +1132,7 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc // It sets required headers, calculates the request body hash, and signs the request // using the provided AWS credentials. // Returns a BifrostError if signing fails. -func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) *schemas.BifrostError { +func signAWSRequest(ctx context.Context, req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string, providerName schemas.ModelProvider) *schemas.BifrostError { // Set required headers before signing req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -681,13 +1142,7 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken if req.Body != nil { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error reading request body", - Error: err, - }, - } + return newBifrostOperationError("error reading request body", err, providerName) } // Restore the body for subsequent reads req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -700,54 +1155,654 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken bodyHash = hex.EncodeToString(hash[:]) } - cfg, err := config.LoadDefaultConfig(context.TODO(), - config.WithRegion(region), - config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { - creds := aws.Credentials{ - AccessKeyID: accessKey, - SecretAccessKey: secretKey, - } - if sessionToken != nil { - creds.SessionToken = *sessionToken - } - return creds, nil - })), - ) + var cfg aws.Config + var err error + + // If both accessKey and secretKey are empty, use the default credential provider chain + // This will automatically use IAM roles, environment variables, shared credentials, etc. + if accessKey == "" && secretKey == "" { + cfg, err = config.LoadDefaultConfig(ctx, + config.WithRegion(region), + ) + } else { + // Use explicit credentials when provided + cfg, err = config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { + creds := aws.Credentials{ + AccessKeyID: accessKey, + SecretAccessKey: secretKey, + } + if sessionToken != nil && *sessionToken != "" { + creds.SessionToken = *sessionToken + } + return creds, nil + })), + ) + } if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to load aws config", - Error: err, - }, - } + return newBifrostOperationError("failed to load aws config", err, providerName) } // Create the AWS signer signer := v4.NewSigner() // Get credentials - creds, err := cfg.Credentials.Retrieve(context.TODO()) + creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to retrieve aws credentials", - Error: err, - }, - } + return newBifrostOperationError("failed to retrieve aws credentials", err, providerName) } // Sign the request with AWS Signature V4 - if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to sign request", - Error: err, + if err := signer.SignHTTP(ctx, creds, req, bodyHash, service, region, time.Now()); err != nil { + return newBifrostOperationError("failed to sign request", err, providerName) + } + + return nil +} + +// Embedding generates embeddings for the given input text(s) using Amazon Bedrock. +// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. +func (provider *BedrockProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if key.BedrockKeyConfig == nil { + return nil, newConfigurationError("bedrock key config is not provided", providerName) + } + + switch { + case strings.Contains(model, "amazon.titan-embed-text"): + return provider.handleTitanEmbedding(ctx, model, key, input, params, providerName) + case strings.Contains(model, "cohere.embed"): + return provider.handleCohereEmbedding(ctx, model, key, input, params, providerName) + default: + return nil, newConfigurationError("embedding is not supported for this Bedrock model", providerName) + } +} + +// handleTitanEmbedding handles embedding requests for Amazon Titan models. +func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Titan Text Embeddings V1/V2 - only supports single text input + if input.Text == nil && len(input.Texts) == 0 { + return nil, newConfigurationError("no input text provided for embedding", providerName) + } + + // Validate that only single text input is provided for Titan models + if input.Text == nil && len(input.Texts) > 1 { + return nil, newConfigurationError("Amazon Titan embedding models only support single text input, but multiple texts were provided", providerName) + } + + requestBody := map[string]interface{}{} + + if input.Text != nil { + requestBody["inputText"] = *input.Text + } else if len(input.Texts) == 1 { + requestBody["inputText"] = input.Texts[0] + } + + if params != nil { + // Titan models do not support the dimensions parameter - they have fixed dimensions + if params.Dimensions != nil { + return nil, newConfigurationError("Amazon Titan embedding models do not support custom dimensions parameter", providerName) + } + if params.ExtraParams != nil { + for k, v := range params.ExtraParams { + requestBody[k] = v + } + } + } + + // Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly + path := provider.getModelPath("invoke", model, key) + rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) + if err != nil { + return nil, err + } + + // Parse Titan response from raw message + var titanResp struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` + } + if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { + return nil, newBifrostOperationError("error parsing Titan embedding response", err, providerName) + } + + bifrostResponse := &schemas.BifrostResponse{ + Object: "list", + Data: []schemas.BifrostEmbedding{ + { + Index: 0, + Object: "embedding", + Embedding: schemas.BifrostEmbeddingResponse{ + Embedding2DArray: &[][]float32{titanResp.Embedding}, + }, }, + }, + Model: model, + Usage: &schemas.LLMUsage{ + PromptTokens: titanResp.InputTextTokenCount, + TotalTokens: titanResp.InputTextTokenCount, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + }, + } + + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock. +func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) { + if input.Text == nil && len(input.Texts) == 0 { + return nil, newConfigurationError("no input text provided for embedding", providerName) + } + + requestBody := map[string]interface{}{ + "input_type": "search_document", + } + + if input.Text != nil { + requestBody["texts"] = []string{*input.Text} + } else { + requestBody["texts"] = input.Texts + } + + if params != nil && params.ExtraParams != nil { + maps.Copy(requestBody, params.ExtraParams) + } + + // Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly + path := provider.getModelPath("invoke", model, key) + rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) + if err != nil { + return nil, err + } + + var cohereResp CohereEmbeddingResponse + if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { + return nil, newBifrostOperationError("error parsing embedding response", err, providerName) + } + + return handleCohereEmbeddingResponse(cohereResp, model, params, providerName, rawResponse, provider.sendBackRawResponse) +} + +// ChatCompletionStream performs a streaming chat completion request to Bedrock's API. +// It formats the request, sends it to Bedrock, and processes the streaming response. +// Returns a channel for streaming BifrostResponse objects or an error if the request fails. +func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if key.BedrockKeyConfig == nil { + return nil, newConfigurationError("bedrock key config is not provided", providerName) + } + + messageBody, err := provider.prepareChatCompletionMessages(messages, model) + if err != nil { + return nil, err + } + + preparedParams := prepareParams(params) + + if strings.Contains(model, "anthropic.") { + if _, exists := preparedParams["max_tokens"]; !exists { + preparedParams["max_tokens"] = AnthropicDefaultMaxTokens + } + } + + // Transform tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + tools, err := provider.getChatCompletionTools(params, model) + if err != nil { + return nil, err + } + + toolConfig := map[string]interface{}{ + "tools": tools, + } + + // Add tool choice if specified + if toolChoice := provider.prepareToolChoice(params, model); toolChoice != nil { + toolConfig["toolChoice"] = toolChoice + } + + preparedParams["toolConfig"] = toolConfig + } else { + // Check if conversation history contains tool use/result blocks + // Bedrock requires toolConfig when such blocks are present + hasToolContent, toolsFromHistory := provider.extractToolsFromHistory(messages) + + // If conversation contains tool content but no tools provided in current request, + // include the extracted tools to satisfy Bedrock's toolConfig requirement + if hasToolContent && len(toolsFromHistory) > 0 { + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": toolsFromHistory, + } + } + } + + requestBody := mergeConfig(messageBody, preparedParams) + + // Format the path with proper model identifier for streaming + path := provider.getModelPath("converse-stream", model, key) + + region := "us-east-1" + if key.BedrockKeyConfig.Region != nil { + region = *key.BedrockKeyConfig.Region + } + + // Create the streaming request + jsonBody, jsonErr := sonic.Marshal(requestBody) + if jsonErr != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, providerName) + } + + // Create HTTP request for streaming + req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody)) + if reqErr != nil { + return nil, newBifrostOperationError("error creating request", reqErr, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // If Value is set, use API Key authentication - else use IAM role authentication + if key.Value != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value)) + } else { + // Sign the request using either explicit credentials or IAM role authentication + if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil { + return nil, err + } + } + + // Make the request + resp, respErr := provider.client.Do(req) + if respErr != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + // Process AWS Event Stream format + var messageID string + var usage *schemas.LLMUsage + var finishReason *string + chunkIndex := -1 + + // Read the response body as a continuous stream + reader := bufio.NewReader(resp.Body) + buffer := make([]byte, 1024*1024) // 1MB buffer + var accumulator []byte // Accumulate data across reads + + for { + n, err := reader.Read(buffer) + if err != nil { + if err == io.EOF { + // Process any remaining data in the accumulator + if len(accumulator) > 0 { + _ = provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, model, providerName, responseChan) + } + break + } + provider.logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + return + } + + if n == 0 { + continue + } + + // Append new data to accumulator + accumulator = append(accumulator, buffer[:n]...) + + // Process the accumulated data and get the remaining unprocessed part + remaining := provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, model, providerName, responseChan) + + // Reset accumulator with remaining data + accumulator = remaining + } + + // Send final response + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, params, providerName) + handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) + }() + + return responseChan, nil +} + +// processAWSEventStreamData processes raw AWS Event Stream data and extracts JSON events. +// Returns any remaining unprocessed bytes that should be kept for the next read. +func (provider *BedrockProvider) processAWSEventStreamData( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + data []byte, + messageID *string, + chunkIndex *int, + usage **schemas.LLMUsage, + finishReason **string, + model string, + providerName schemas.ModelProvider, + responseChan chan *schemas.BifrostStream, +) []byte { + lastProcessed := 0 + depth := 0 + inString := false + escaped := false + objStart := -1 + + for i := 0; i < len(data); i++ { + b := data[i] + if inString { + if escaped { + escaped = false + continue + } + switch b { + case '\\': + escaped = true + case '"': + inString = false + } + continue + } + + switch b { + case '"': + inString = true + case '{': + if depth == 0 { + objStart = i + } + depth++ + case '}': + if depth > 0 { + depth-- + if depth == 0 && objStart >= 0 { + jsonBytes := data[objStart : i+1] + // Quick filter to match original behavior - check for JSON content and relevant fields + hasQuotes := bytes.Contains(jsonBytes, []byte(`"`)) + hasRelevantContent := bytes.Contains(jsonBytes, []byte(`role`)) || + bytes.Contains(jsonBytes, []byte(`delta`)) || + bytes.Contains(jsonBytes, []byte(`usage`)) || + bytes.Contains(jsonBytes, []byte(`stopReason`)) || + bytes.Contains(jsonBytes, []byte(`contentBlockIndex`)) || + bytes.Contains(jsonBytes, []byte(`metadata`)) + + if hasQuotes && hasRelevantContent { + provider.processEventBuffer(ctx, postHookRunner, jsonBytes, messageID, chunkIndex, usage, finishReason, model, providerName, responseChan) + lastProcessed = i + 1 + } + objStart = -1 + } + } + default: + // skip } } + if lastProcessed < len(data) { + return data[lastProcessed:] + } return nil } + +// processEventBuffer processes AWS Event Stream JSON payloads and determines event type from content +func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHookRunner schemas.PostHookRunner, eventBuffer []byte, messageID *string, chunkIndex *int, usage **schemas.LLMUsage, finishReason **string, model string, providerName schemas.ModelProvider, responseChan chan *schemas.BifrostStream) { + // Parse the JSON event + var event map[string]interface{} + if err := sonic.Unmarshal(eventBuffer, &event); err != nil { + provider.logger.Debug(fmt.Sprintf("Failed to parse JSON from event buffer: %v, data: %s", err, string(eventBuffer))) + return + } + + // Determine event type based on JSON content structure + switch { + case event["role"] != nil: + // This is a messageStart event + *chunkIndex++ + if role, ok := event["role"].(string); ok { + *messageID = fmt.Sprintf("bedrock-%d", time.Now().UnixNano()) + + // Send empty response to signal start + streamResponse := &schemas.BifrostResponse{ + ID: *messageID, + Object: "chat.completion.chunk", + Model: model, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Role: &role, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: *chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, provider.logger) + } + + case event["contentBlockIndex"] != nil && event["delta"] != nil: + // This is a contentBlockDelta event + *chunkIndex++ + contentBlockIndex := 0 + if idx, ok := event["contentBlockIndex"].(float64); ok { + contentBlockIndex = int(idx) + } + + if delta, ok := event["delta"].(map[string]interface{}); ok { + switch { + case delta["text"] != nil: + // Handle text delta + if text, ok := delta["text"].(string); ok && text != "" { + // Create streaming response for this delta + streamResponse := &schemas.BifrostResponse{ + ID: *messageID, + Object: "chat.completion.chunk", + Model: model, + Choices: []schemas.BifrostResponseChoice{ + { + Index: contentBlockIndex, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Content: &text, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: *chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, provider.logger) + } + + case delta["toolUse"] != nil: + // Handle tool use delta + if toolUse, ok := delta["toolUse"].(map[string]interface{}); ok { + // Parse the tool use structure properly + var toolCall schemas.ToolCall + toolCall.Type = func() *string { s := "function"; return &s }() + + // Extract toolUseId + if toolUseID, hasID := toolUse["toolUseId"].(string); hasID { + toolCall.ID = &toolUseID + } + + // Extract name + if name, hasName := toolUse["name"].(string); hasName { + toolCall.Function.Name = &name + } + + // Extract and marshal input as arguments + if input, hasInput := toolUse["input"].(map[string]interface{}); hasInput { + inputBytes, err := sonic.Marshal(input) + if err != nil { + toolCall.Function.Arguments = "{}" + } else { + toolCall.Function.Arguments = string(inputBytes) + } + } else { + toolCall.Function.Arguments = "{}" + } + + // Create streaming response for tool delta + streamResponse := &schemas.BifrostResponse{ + ID: *messageID, + Object: "chat.completion.chunk", + Model: model, + Choices: []schemas.BifrostResponseChoice{ + { + Index: contentBlockIndex, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + ToolCalls: []schemas.ToolCall{toolCall}, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: *chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, provider.logger) + } + } + } + + case event["stopReason"] != nil: + // This is a messageStop event + if stopReason, ok := event["stopReason"].(string); ok { + *finishReason = &stopReason + } + + case event["usage"] != nil: + // This is a metadata event with usage information at top level + if usageData, ok := event["usage"].(map[string]interface{}); ok { + inputTokens := 0 + outputTokens := 0 + totalTokens := 0 + + if val, exists := usageData["inputTokens"].(float64); exists { + inputTokens = int(val) + } + if val, exists := usageData["outputTokens"].(float64); exists { + outputTokens = int(val) + } + if val, exists := usageData["totalTokens"].(float64); exists { + totalTokens = int(val) + } + + *usage = &schemas.LLMUsage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: totalTokens, + } + } + + case event["metadata"] != nil: + // This is a metadata event - check if it contains nested usage information + if metadata, ok := event["metadata"].(map[string]interface{}); ok { + if usageData, ok := metadata["usage"].(map[string]interface{}); ok { + inputTokens := 0 + outputTokens := 0 + totalTokens := 0 + + if val, exists := usageData["inputTokens"].(float64); exists { + inputTokens = int(val) + } + if val, exists := usageData["outputTokens"].(float64); exists { + outputTokens = int(val) + } + if val, exists := usageData["totalTokens"].(float64); exists { + totalTokens = int(val) + } + + *usage = &schemas.LLMUsage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: totalTokens, + } + } + } + + default: + // Log unknown event types for debugging + provider.logger.Debug(fmt.Sprintf("Unknown event type received: %v", event)) + } +} + +func (provider *BedrockProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "bedrock") +} + +func (provider *BedrockProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "bedrock") +} + +func (provider *BedrockProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "bedrock") +} + +func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "bedrock") +} + +func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string { + // Format the path with proper model identifier for streaming + path := fmt.Sprintf("%s/%s", model, basePath) + + if key.BedrockKeyConfig.Deployments != nil { + if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok { + if key.BedrockKeyConfig.ARN != nil { + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId)) + path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath) + } + } + } + + return path +} diff --git a/core/providers/cerebras.go b/core/providers/cerebras.go new file mode 100644 index 000000000..95f114912 --- /dev/null +++ b/core/providers/cerebras.go @@ -0,0 +1,356 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Cerebras provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// cerebrasTextResponsePool provides a pool for Cerebras text completion response objects. +var cerebrasTextResponsePool = sync.Pool{ + New: func() interface{} { + return &AzureTextResponse{} + }, +} + +// // cerebrasChatResponsePool provides a pool for Cerebras chat response objects. +// var cerebrasChatResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireCerebrasChatResponse gets a Cerebras response from the pool and resets it. +// func acquireCerebrasChatResponse() *schemas.BifrostResponse { +// resp := cerebrasChatResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseCerebrasChatResponse returns a Cerebras response to the pool. +// func releaseCerebrasChatResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// cerebrasChatResponsePool.Put(resp) +// } +// } + +// acquireCerebrasTextResponse gets a Cerebras text completion response from the pool and resets it. +func acquireCerebrasTextResponse() *AzureTextResponse { + resp := cerebrasTextResponsePool.Get().(*AzureTextResponse) + *resp = AzureTextResponse{} // Reset the struct + return resp +} + +// releaseCerebrasTextResponse returns a Cerebras text completion response to the pool. +func releaseCerebrasTextResponse(resp *AzureTextResponse) { + if resp != nil { + cerebrasTextResponsePool.Put(resp) + } +} + +// CerebrasProvider implements the Provider interface for Cerebras's API. +type CerebrasProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewCerebrasProvider creates a new Cerebras provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*CerebrasProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + // cerebrasChatResponsePool.Put(&schemas.BifrostResponse{}) + cerebrasTextResponsePool.Put(&AzureTextResponse{}) + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.cerebras.ai" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &CerebrasProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +// GetProviderKey returns the provider identifier for Cerebras. +func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Cerebras +} + +// TextCompletion performs a text completion request to Cerebras's API. +// It formats the request, sends it to Cerebras, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *CerebrasProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + preparedParams := prepareParams(params) + + // Merge additional parameters + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "prompt": text, + }, preparedParams) + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cerebras) + } + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from cerebras provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Cerebras error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + response := acquireCerebrasTextResponse() + defer releaseCerebrasTextResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + choices := []schemas.BifrostResponseChoice{} + + // Create the completion result + if len(response.Choices) > 0 { + // Copy text content to avoid pointer to pooled memory + textCopy := response.Choices[0].Text + + choices = append(choices, schemas.BifrostResponseChoice{ + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &textCopy, + }, + }, + LogProbs: &schemas.LogProbs{ + Text: response.Choices[0].LogProbs, + }, + }, + FinishReason: response.Choices[0].FinishReason, + }) + } + + // Copy Usage struct to avoid pointer to pooled memory + usageCopy := response.Usage + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: &usageCopy, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Cerebras, + }, + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// ChatCompletion performs a chat completion request to the Cerebras API. +func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cerebras) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from cerebras provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Cerebras error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireCerebrasChatResponse() + // defer releaseCerebrasChatResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response.ExtraFields.Provider = schemas.Cerebras + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding is not supported by the Cerebras provider. +func (provider *CerebrasProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "cerebras") +} + +// ChatCompletionStream performs a streaming chat completion request to the Cerebras API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Cerebras's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Cerebras headers + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + headers["Authorization"] = "Bearer " + key.Value + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Cerebras, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *CerebrasProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "cerebras") +} + +func (provider *CerebrasProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "cerebras") +} + +func (provider *CerebrasProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "cerebras") +} + +func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "cerebras") +} diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 0240af086..fcf03888f 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -3,13 +3,19 @@ package providers import ( + "bufio" + "bytes" + "context" "fmt" + "io" "slices" + "strings" "sync" "time" - "github.com/goccy/go-json" + "net/http" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -61,9 +67,8 @@ type CohereToolCall struct { // CohereChatResponse represents the response from Cohere's chat API. // It includes the response ID, generated text, chat history, and usage statistics. type CohereChatResponse struct { - ResponseID string `json:"response_id"` // Unique identifier for the response - Text string `json:"text"` // Generated text response GenerationID string `json:"generation_id"` // ID of the generation + Text string `json:"text"` // Generated text response ChatHistory []struct { Role schemas.ModelChatMessageRole `json:"role"` // Role of the message sender Message string `json:"message"` // Content of the message @@ -75,8 +80,10 @@ type CohereChatResponse struct { Version string `json:"version"` // Version of the API used } `json:"api_version"` // API version information BilledUnits struct { - InputTokens float64 `json:"input_tokens"` // Number of input tokens billed - OutputTokens float64 `json:"output_tokens"` // Number of output tokens billed + InputTokens float64 `json:"input_tokens"` // Number of input tokens billed + OutputTokens float64 `json:"output_tokens"` // Number of output tokens billed + Classifications float64 `json:"classifications"` // Number of classifications billed + SearchUnits float64 `json:"search_units"` // Number of search units billed } `json:"billed_units"` // Token usage billing information Tokens struct { InputTokens float64 `json:"input_tokens"` // Number of input tokens used @@ -91,114 +98,144 @@ type CohereError struct { Message string `json:"message"` // Error message } +// CohereEmbeddingResponse represents the response from Cohere's embedding API. +type CohereEmbeddingResponse struct { + ID string `json:"id"` // Unique identifier for the embedding request + Embeddings struct { + Float [][]float32 `json:"float"` // Array of float embeddings, one for each input text + } `json:"embeddings"` // Embeddings in the response + Texts []string `json:"texts"` // Texts that were embedded + Meta struct { + APIVersion struct { + Version string `json:"version"` // Version of the API used + IsExperimental bool `json:"is_experimental"` // Whether the API is experimental + } `json:"api_version"` // API version information + BilledUnits struct { + InputTokens float64 `json:"input_tokens"` // Number of input tokens billed + OutputTokens float64 `json:"output_tokens"` // Number of output tokens billed + Classifications float64 `json:"classifications"` // Number of classifications billed + SearchUnits float64 `json:"search_units"` // Number of search units billed + } `json:"billed_units"` // Token usage billing information + Tokens struct { + InputTokens float64 `json:"input_tokens"` // Number of input tokens used + OutputTokens float64 `json:"output_tokens"` // Number of output tokens generated + } `json:"tokens"` // Token usage statistics + Warnings []string `json:"warnings"` // Warnings about the response + } `json:"meta"` // Metadata about the response +} + // CohereProvider implements the Provider interface for Cohere. type CohereProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse + customProviderConfig *schemas.CustomProviderConfig // Custom provider config +} + +// CohereStreamStartEvent represents the start of a stream event. +type CohereStreamStartEvent struct { + EventType string `json:"event_type"` // stream-start + GenerationID string `json:"generation_id"` // ID of the generation +} + +// CohereStreamTextEvent represents the text generation event. +type CohereStreamTextEvent struct { + EventType string `json:"event_type"` // text-generation + Text string `json:"text"` // Text content being generated +} + +// CohereStreamToolEvent represents the tool use event. +type CohereStreamToolCallEvent struct { + EventType string `json:"event_type"` // tool-use + ToolCall struct { + ID string `json:"id"` // ID of the tool call + Parameters string `json:"parameters"` // Parameters of the tool being called + } `json:"tool_call"` // Tool call information + Text *string `json:"text"` // Text content being generated +} + +// CohereStreamStopEvent represents the end of a stream event. +type CohereStreamStopEvent struct { + EventType string `json:"event_type"` // stream-end + Response CohereChatResponse `json:"response"` // Response information } // NewCohereProvider creates a new Cohere provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and connection limits. func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *CohereProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } // Pre-warm response pools - for range config.ConcurrencyAndBufferSize.Concurrency { + for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { cohereResponsePool.Put(&CohereChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) } + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.cohere.ai" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + return &CohereProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + customProviderConfig: config.CustomProviderConfig, + sendBackRawResponse: config.SendBackRawResponse, } } // GetProviderKey returns the provider identifier for Cohere. func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { - return schemas.Cohere + return getProviderName(schemas.Cohere, provider.customProviderConfig) } // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. -func (provider *CohereProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by cohere provider", - }, - } +func (provider *CohereProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "cohere") } // ChatCompletion performs a chat completion request to the Cohere API. // It formats the request, sends it to Cohere, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CohereProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Get the last message and chat history - lastMessage := messages[len(messages)-1] - chatHistory := messages[:len(messages)-1] - - // Transform chat history - var cohereHistory []map[string]interface{} - for _, msg := range chatHistory { - cohereHistory = append(cohereHistory, map[string]interface{}{ - "role": msg.Role, - "message": msg.Content, - }) +func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if chat completion is allowed + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + return nil, err } - preparedParams := prepareParams(params) - - // Prepare request body - requestBody := mergeConfig(map[string]interface{}{ - "message": lastMessage.Content, - "chat_history": cohereHistory, - "model": model, - }, preparedParams) - - // Add tools if present - if params != nil && params.Tools != nil && len(*params.Tools) > 0 { - var tools []CohereTool - for _, tool := range *params.Tools { - parameterDefinitions := make(map[string]CohereParameterDefinition) - params := tool.Function.Parameters - for name, prop := range tool.Function.Parameters.Properties { - propMap, ok := prop.(map[string]interface{}) - if ok { - paramDef := CohereParameterDefinition{ - Required: slices.Contains(params.Required, name), - } - - if typeStr, ok := propMap["type"].(string); ok { - paramDef.Type = typeStr - } - - if desc, ok := propMap["description"].(string); ok { - paramDef.Description = &desc - } - - parameterDefinitions[name] = paramDef - } - } + providerName := provider.GetProviderKey() - tools = append(tools, CohereTool{ - Name: tool.Function.Name, - Description: tool.Function.Description, - ParameterDefinitions: parameterDefinitions, - }) + // Prepare request body using shared function + requestBody, err := prepareCohereChatRequest(messages, params, model, false) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("failed to prepare %s chat request", providerName), + Error: err, + }, } - requestBody["tools"] = tools } // Marshal request body - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -215,25 +252,26 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - req.SetRequestURI("https://api.cohere.ai/v1/chat") + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) + req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + var errorResp CohereError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -249,11 +287,7 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch response := acquireCohereResponse() defer releaseCohereResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } @@ -266,7 +300,7 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch Name: &tool.Name, } - args, err := json.Marshal(tool.Parameters) + args, err := sonic.Marshal(tool.Parameters) if err != nil { function.Arguments = fmt.Sprintf("%v", tool.Parameters) } else { @@ -287,49 +321,275 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch role = lastMsg.Role content = lastMsg.Message } else { - role = schemas.RoleChatbot + role = schemas.ModelChatMessageRoleChatbot content = response.Text } - bifrostResponse.ID = response.ResponseID - bifrostResponse.Choices = []schemas.BifrostResponseChoice{ - { - Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: role, - Content: &content, - ToolCalls: &toolCalls, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.GenerationID, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: role, + Content: schemas.MessageContent{ + ContentStr: &content, + }, + AssistantMessage: &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + }, + }, + }, + FinishReason: &response.FinishReason, }, - FinishReason: &response.FinishReason, }, - } - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: int(response.Meta.Tokens.InputTokens), - CompletionTokens: int(response.Meta.Tokens.OutputTokens), - TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), - } - bifrostResponse.Model = model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Cohere, - BilledUsage: &schemas.BilledLLMUsage{ - PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), - CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), + Usage: &schemas.LLMUsage{ + PromptTokens: int(response.Meta.Tokens.InputTokens), + CompletionTokens: int(response.Meta.Tokens.OutputTokens), + TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), }, - ChatHistory: convertChatHistory(response.ChatHistory), - RawResponse: rawResponse, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + BilledUsage: &schemas.BilledLLMUsage{ + PromptTokens: Ptr(response.Meta.BilledUnits.InputTokens), + CompletionTokens: Ptr(response.Meta.BilledUnits.OutputTokens), + Classifications: Ptr(response.Meta.BilledUnits.Classifications), + SearchUnits: Ptr(response.Meta.BilledUnits.SearchUnits), + }, + ChatHistory: convertChatHistory(response.ChatHistory), + }, + } + + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil } +// prepareCohereChatRequest prepares the request body for Cohere chat completion requests. +// It transforms the messages into Cohere format and handles tools, parameters, and content formatting. +func prepareCohereChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters, model string, stream bool) (map[string]interface{}, error) { + // Get the last message and chat history + lastMessage := messages[len(messages)-1] + chatHistory := messages[:len(messages)-1] + + // Transform chat history + var cohereHistory []map[string]interface{} + for _, msg := range chatHistory { + historyMsg := map[string]interface{}{ + "role": msg.Role, + } + + if msg.Role == schemas.ModelChatMessageRoleAssistant { + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + var toolCalls []map[string]interface{} + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var arguments map[string]interface{} + var parsedJSON interface{} + err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + arguments = arr + } else { + arguments = map[string]interface{}{"content": parsedJSON} + } + } else { + arguments = map[string]interface{}{"content": toolCall.Function.Arguments} + } + + toolCalls = append(toolCalls, map[string]interface{}{ + "name": toolCall.Function.Name, + "parameters": arguments, + }) + } + historyMsg["tool_calls"] = toolCalls + } + } else if msg.Role == schemas.ModelChatMessageRoleTool { + // Find the original tool call parameters from conversation history + var toolCallParameters map[string]interface{} + + // Look back through the chat history to find the assistant message with the matching tool call + for i := len(chatHistory) - 1; i >= 0; i-- { + prevMsg := chatHistory[i] + if prevMsg.Role == schemas.ModelChatMessageRoleAssistant && + prevMsg.AssistantMessage != nil && + prevMsg.AssistantMessage.ToolCalls != nil { + + // Search through tool calls in this assistant message + for _, toolCall := range *prevMsg.AssistantMessage.ToolCalls { + if toolCall.ID != nil && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil && + *toolCall.ID == *msg.ToolMessage.ToolCallID { + + // Found the matching tool call, extract its parameters + var parsedJSON interface{} + err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + toolCallParameters = arr + } else { + toolCallParameters = map[string]interface{}{"content": parsedJSON} + } + } else { + toolCallParameters = map[string]interface{}{"content": toolCall.Function.Arguments} + } + break + } + } + + // If we found the parameters, stop searching + if toolCallParameters != nil { + break + } + } + } + + // If no parameters found, use empty map as fallback + if toolCallParameters == nil { + toolCallParameters = map[string]interface{}{} + } + + toolResults := []map[string]interface{}{ + { + "call": map[string]interface{}{ + "name": *msg.ToolMessage.ToolCallID, + "parameters": toolCallParameters, + }, + "outputs": *msg.Content.ContentStr, + }, + } + + historyMsg["tool_results"] = toolResults + } + + if msg.Content.ContentStr != nil { + historyMsg["message"] = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + // Create content array with text and image + contentArray := []map[string]interface{}{} + + // Iterate over ContentBlocks to build the content array + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + contentArray = append(contentArray, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + // Add image content using our helper function + // NOTE: Cohere v1 does not support image content + // if processedImageContent := processImageContent(block.ImageContent); processedImageContent != nil { + // contentArray = append(contentArray, processedImageContent) + // } + } + + historyMsg["content"] = contentArray + } + + cohereHistory = append(cohereHistory, historyMsg) + } + + preparedParams := prepareParams(params) + + // Prepare request body + requestBody := mergeConfig(map[string]interface{}{ + "chat_history": cohereHistory, + "model": model, + }, preparedParams) + + // Add stream parameter if streaming + if stream { + requestBody["stream"] = true + } + + if lastMessage.Content.ContentStr != nil { + requestBody["message"] = *lastMessage.Content.ContentStr + } else if lastMessage.Content.ContentBlocks != nil { + message := "" + for _, block := range *lastMessage.Content.ContentBlocks { + if block.Text != nil { + message += *block.Text + "\n" + } + } + requestBody["message"] = strings.TrimSuffix(message, "\n") + } + + // Add tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + var tools []CohereTool + for _, tool := range *params.Tools { + parameterDefinitions := make(map[string]CohereParameterDefinition) + params := tool.Function.Parameters + for name, prop := range tool.Function.Parameters.Properties { + propMap, ok := prop.(map[string]interface{}) + if ok { + paramDef := CohereParameterDefinition{ + Required: slices.Contains(params.Required, name), + } + + if typeStr, ok := propMap["type"].(string); ok { + paramDef.Type = typeStr + } + + if desc, ok := propMap["description"].(string); ok { + paramDef.Description = &desc + } + + parameterDefinitions[name] = paramDef + } + } + + tools = append(tools, CohereTool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + ParameterDefinitions: parameterDefinitions, + }) + } + requestBody["tools"] = tools + } + + // Add tool choice if present + if params != nil && params.ToolChoice != nil { + if params.ToolChoice.ToolChoiceStr != nil { + requestBody["tool_choice"] = *params.ToolChoice.ToolChoiceStr + } else if params.ToolChoice.ToolChoiceStruct != nil { + requestBody["tool_choice"] = map[string]interface{}{ + "type": strings.ToUpper(string(params.ToolChoice.ToolChoiceStruct.Type)), + } + } + } + + return requestBody, nil +} + +// processImageContent processes image content for Cohere API format. +// NOTE: Cohere v1 does not support image content, so this function is a placeholder. +// It returns nil since image processing is not available. +func processImageContent(imageContent *schemas.ImageURLStruct) map[string]interface{} { + if imageContent == nil { + return nil + } + + // Cohere v1 does not support image content + // Return nil to skip image processing + return nil +} + // convertChatHistory converts Cohere's chat history format to Bifrost's format for standardization. // It transforms the chat history messages and their tool calls. func convertChatHistory(history []struct { Role schemas.ModelChatMessageRole `json:"role"` Message string `json:"message"` ToolCalls []CohereToolCall `json:"tool_calls"` -}) *[]schemas.BifrostResponseChoiceMessage { - converted := make([]schemas.BifrostResponseChoiceMessage, len(history)) +}) *[]schemas.BifrostMessage { + converted := make([]schemas.BifrostMessage, len(history)) for i, msg := range history { var toolCalls []schemas.ToolCall if msg.ToolCalls != nil { @@ -338,7 +598,7 @@ func convertChatHistory(history []struct { Name: &tool.Name, } - args, err := json.Marshal(tool.Parameters) + args, err := sonic.Marshal(tool.Parameters) if err != nil { function.Arguments = fmt.Sprintf("%v", tool.Parameters) } else { @@ -350,11 +610,452 @@ func convertChatHistory(history []struct { }) } } - converted[i] = schemas.BifrostResponseChoiceMessage{ - Role: msg.Role, - Content: &msg.Message, - ToolCalls: &toolCalls, + + converted[i] = schemas.BifrostMessage{ + Role: msg.Role, + Content: schemas.MessageContent{ + ContentStr: &msg.Message, + }, + AssistantMessage: &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + }, } } return &converted } + +// Embedding generates embeddings for the given input text(s) using the Cohere API. +// Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). +func (provider *CohereProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if embedding is allowed + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Prepare request body with default values + requestBody := map[string]interface{}{ + "model": model, + "input_type": "search_document", // Default input type - can be overridden via ExtraParams + "embedding_types": []string{"float"}, // Default to float embeddings + } + + if input.Text != nil { + requestBody["texts"] = []string{*input.Text} + } else { + requestBody["texts"] = input.Texts + } + + // Apply additional parameters if provided + if params != nil { + // Validate encoding format - Cohere API supports float, int8, uint8, binary, ubinary, but our provider only implements float + if params.EncodingFormat != nil { + if *params.EncodingFormat != "float" { + return nil, newConfigurationError(fmt.Sprintf("provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), providerName) + } + // Override default with the specified format + requestBody["embedding_types"] = []string{*params.EncodingFormat} + } + + // Merge extra parameters - this allows overriding input_type and other parameters + if params.ExtraParams != nil { + for k, v := range params.ExtraParams { + requestBody[k] = v + } + } + } + + // Marshal request body + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + + var errorResp CohereError + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = errorResp.Message + + return nil, bifrostErr + } + + // Parse response + var cohereResp CohereEmbeddingResponse + if err := sonic.Unmarshal(resp.Body(), &cohereResp); err != nil { + return nil, newBifrostOperationError("error parsing embedding response", err, providerName) + } + + // Parse raw response for consistent format + var rawResponse interface{} + if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil { + return nil, newBifrostOperationError("error parsing raw response for embedding", err, providerName) + } + + return handleCohereEmbeddingResponse(cohereResp, model, params, providerName, rawResponse, provider.sendBackRawResponse) +} + +func handleCohereEmbeddingResponse(cohereResp CohereEmbeddingResponse, model string, params *schemas.ModelParameters, providerName schemas.ModelProvider, rawResponse interface{}, sendBackRawResponse bool) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Create BifrostResponse + bifrostResponse := &schemas.BifrostResponse{ + ID: cohereResp.ID, + Object: "list", + Data: []schemas.BifrostEmbedding{ + { + Index: 0, + Object: "embedding", + Embedding: schemas.BifrostEmbeddingResponse{ + Embedding2DArray: &cohereResp.Embeddings.Float, + }, + }, + }, + Model: model, + Usage: &schemas.LLMUsage{ + PromptTokens: int(cohereResp.Meta.Tokens.InputTokens), + CompletionTokens: int(cohereResp.Meta.Tokens.OutputTokens), + TotalTokens: int(cohereResp.Meta.Tokens.InputTokens + cohereResp.Meta.Tokens.OutputTokens), + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + BilledUsage: &schemas.BilledLLMUsage{ + PromptTokens: Ptr(cohereResp.Meta.BilledUnits.InputTokens), + CompletionTokens: Ptr(cohereResp.Meta.BilledUnits.OutputTokens), + Classifications: Ptr(cohereResp.Meta.BilledUnits.Classifications), + SearchUnits: Ptr(cohereResp.Meta.BilledUnits.SearchUnits), + }, + }, + } + + // Only include RawResponse if sendBackRawResponse is enabled + if sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil + +} + +// ChatCompletionStream performs a streaming chat completion request to the Cohere API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if chat completion stream is allowed + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Prepare request body using shared function + requestBody, err := prepareCohereChatRequest(messages, params, model, true) + if err != nil { + return nil, newBifrostOperationError("failed to prepare chat request", err, providerName) + } + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/chat", bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // Make the request + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + chunkIndex := -1 + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + var responseID string + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData := strings.TrimPrefix(line, "data: ") + + // Parse the streaming event + var streamEvent map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &streamEvent); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream event: %v", err)) + continue + } + + eventType, exists := streamEvent["event_type"].(string) + if !exists { + continue + } + + chunkIndex++ + + switch eventType { + case "stream-start": + var startEvent CohereStreamStartEvent + if err := sonic.Unmarshal([]byte(jsonData), &startEvent); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream-start event: %v", err)) + continue + } + + responseID = startEvent.GenerationID + + // Send empty message to signal stream start + streamResponse := &schemas.BifrostResponse{ + ID: responseID, + Object: "chat.completion.chunk", + Model: model, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Role: Ptr(string(schemas.ModelChatMessageRoleAssistant)), + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan, provider.logger) + + case "text-generation": + var textEvent CohereStreamTextEvent + if err := sonic.Unmarshal([]byte(jsonData), &textEvent); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse text-generation event: %v", err)) + continue + } + + // Create response for this text chunk + response := &schemas.BifrostResponse{ + ID: responseID, + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Content: &textEvent.Text, + }, + }, + FinishReason: nil, // Not finished yet + }, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger) + + case "tool-calls-chunk": + var toolEvent CohereStreamToolCallEvent + if err := sonic.Unmarshal([]byte(jsonData), &toolEvent); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse tool-use event: %v", err)) + continue + } + + toolCall := schemas.ToolCall{ + ID: &toolEvent.ToolCall.ID, + Function: schemas.FunctionCall{ + Name: &toolEvent.ToolCall.ID, + Arguments: toolEvent.ToolCall.Parameters, + }, + } + + // Create response for tool calls + response := &schemas.BifrostResponse{ + ID: responseID, + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + ToolCalls: []schemas.ToolCall{toolCall}, + Content: toolEvent.Text, + }, + }, + FinishReason: nil, + }, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger) + + case "stream-end": + var stopEvent CohereStreamStopEvent + if err := sonic.Unmarshal([]byte(jsonData), &stopEvent); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream-end event: %v", err)) + continue + } + + // Convert tool calls from the final response + var toolCalls []schemas.ToolCall + for _, toolCall := range stopEvent.Response.ToolCalls { + function := schemas.FunctionCall{ + Name: &toolCall.Name, + } + + args, err := sonic.Marshal(toolCall.Parameters) + if err != nil { + function.Arguments = fmt.Sprintf("%v", toolCall.Parameters) + } else { + function.Arguments = string(args) + } + + toolCalls = append(toolCalls, schemas.ToolCall{ + Function: function, + }) + } + + // Send final response with complete content from the stopEvent + response := &schemas.BifrostResponse{ + ID: responseID, + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Role: Ptr(string(schemas.ModelChatMessageRoleAssistant)), + Content: &stopEvent.Response.Text, + ToolCalls: toolCalls, + }, + }, + FinishReason: &stopEvent.Response.FinishReason, + }, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + if params != nil { + response.ExtraFields.Params = *params + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + + // Use utility function to process and send response + processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger) + + return // End of stream + + default: + // Unknown event type, log and continue + provider.logger.Debug(fmt.Sprintf("Unknown stream event type: %s", eventType)) + } + } + } + + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + } + }() + + return responseChan, nil +} + +func (provider *CohereProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "cohere") +} + +func (provider *CohereProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "cohere") +} + +func (provider *CohereProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "cohere") +} + +func (provider *CohereProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "cohere") +} diff --git a/core/providers/gemini.go b/core/providers/gemini.go new file mode 100644 index 000000000..b0c9be162 --- /dev/null +++ b/core/providers/gemini.go @@ -0,0 +1,1168 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Gemini provider implementation. +package providers + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// Response message for PredictionService.GenerateContent. +type GenerateContentResponse struct { + // Response variations returned by the model. + Candidates []*Candidate `json:"candidates,omitempty"` + // Usage metadata about the response(s). + UsageMetadata *GenerateContentResponseUsageMetadata `json:"usageMetadata,omitempty"` +} + +// A response candidate generated from the model. +type Candidate struct { + // Optional. Contains the multi-part content of the response. + Content *Content `json:"content,omitempty"` + // Optional. The reason why the model stopped generating tokens. + // If empty, the model has not stopped generating the tokens. + FinishReason string `json:"finishReason,omitempty"` + // Output only. Index of the candidate. + Index int32 `json:"index,omitempty"` +} + +// Contains the multi-part content of a message. +type Content struct { + // Optional. List of parts that constitute a single message. Each part may have + // a different IANA MIME type. + Parts []*Part `json:"parts,omitempty"` + // Optional. The producer of the content. Must be either 'user' or + // 'model'. Useful to set for multi-turn conversations, otherwise can be + // empty. If role is not specified, SDK will determine the role. + Role string `json:"role,omitempty"` +} + +// A datatype containing media content. +// Exactly one field within a Part should be set, representing the specific type +// of content being conveyed. Using multiple fields within the same `Part` +// instance is considered invalid. +type Part struct { + // Optional. Inlined bytes data. + InlineData *Blob `json:"inlineData,omitempty"` + // Optional. Text part (can be code). + Text string `json:"text,omitempty"` +} + +// Content blob. +type Blob struct { + // Required. Raw bytes. + Data []byte `json:"data,omitempty"` +} + +// Usage metadata about response(s). +type GenerateContentResponseUsageMetadata struct { + // Number of tokens in the response(s). This includes all the generated response candidates. + CandidatesTokenCount int32 `json:"candidatesTokenCount,omitempty"` + // Number of tokens in the prompt. When cached_content is set, this is still the total + // effective prompt size meaning this includes the number of tokens in the cached content. + PromptTokenCount int32 `json:"promptTokenCount,omitempty"` + // Total token count for prompt, response candidates, and tool-use prompts (if present). + TotalTokenCount int32 `json:"totalTokenCount,omitempty"` +} + +type GeminiProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse + customProviderConfig *schemas.CustomProviderConfig // Custom provider config +} + +// NewGeminiProvider creates a new Gemini provider instance. +// It initializes the HTTP client with the provided configuration. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewGeminiProvider(config *schemas.ProviderConfig, logger schemas.Logger) *GeminiProvider { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://generativelanguage.googleapis.com/v1beta" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &GeminiProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + customProviderConfig: config.CustomProviderConfig, + sendBackRawResponse: config.SendBackRawResponse, + } +} + +// GetProviderKey returns the provider identifier for Gemini. +func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { + return getProviderName(schemas.Gemini, provider.customProviderConfig) +} + +// TextCompletion is not supported by the Gemini provider. +func (provider *GeminiProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", string(provider.GetProviderKey())) +} + +// ChatCompletion performs a chat completion request to the Gemini API. +func (provider *GeminiProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if chat completion is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/openai/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("%s error: %v", providerName, errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireGeminiResponse() + // defer releaseGeminiResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + for _, choice := range response.Choices { + if choice.Message.AssistantMessage == nil || choice.Message.AssistantMessage.ToolCalls == nil { + continue + } + for i, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + if (toolCall.ID == nil || *toolCall.ID == "") && toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + id := *toolCall.Function.Name + (*choice.Message.AssistantMessage.ToolCalls)[i].ID = &id + } + } + } + + response.ExtraFields.Provider = providerName + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// ChatCompletionStream performs a streaming chat completion request to the Gemini API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Gemini's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if chat completion stream is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Gemini headers + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/openai/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + providerName, + params, + postHookRunner, + provider.logger, + ) +} + +// Embedding performs an embedding request to the Gemini API. +func (provider *GeminiProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if embedding is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if input.Text == nil && len(input.Texts) == 0 { + return nil, newBifrostOperationError("invalid embedding input: at least one text is required", nil, providerName) + } + + // Prepare request body with base parameters + requestBody := map[string]interface{}{ + "model": model, + } + + if input.Text != nil { + requestBody["input"] = []string{*input.Text} + } else { + requestBody["input"] = input.Texts + } + + // Merge any additional parameters + if params != nil { + // Map standard parameters + if params.EncodingFormat != nil { + requestBody["encoding_format"] = *params.EncodingFormat + } + if params.Dimensions != nil { + requestBody["dimensions"] = *params.Dimensions + } + if params.User != nil { + requestBody["user"] = *params.User + } + + // Merge any extra parameters + if params.ExtraParams != nil { + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + } + + // Use the shared embedding request handler + return handleOpenAIEmbeddingRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/openai/embeddings", + requestBody, + key, + params, + provider.networkConfig.ExtraHeaders, + providerName, + provider.sendBackRawResponse, + provider.logger, + ) +} + +func (provider *GeminiProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if speech is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationSpeech); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Validate input + if input == nil || input.Input == "" { + return nil, newBifrostOperationError("invalid speech input: no text provided", fmt.Errorf("empty text input"), providerName) + } + + // Prepare request body using shared function + requestBody := prepareGeminiGenerationRequest(input, params, []string{"AUDIO"}) + + // Use common request function + bifrostResponse, geminiResponse, bifrostErr := provider.completeRequest(ctx, model, key, requestBody, ":generateContent", params) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Extract audio data from response + var audioData []byte + if len(geminiResponse.Candidates) > 0 && geminiResponse.Candidates[0].Content != nil { + for _, part := range geminiResponse.Candidates[0].Content.Parts { + if part.InlineData != nil && part.InlineData.Data != nil { + audioData = append(audioData, part.InlineData.Data...) + } + } + } + + if len(audioData) == 0 { + return nil, newBifrostOperationError("no audio data received from Gemini", fmt.Errorf("empty audio response"), providerName) + } + + // Extract usage metadata using shared function + inputTokens, outputTokens, totalTokens := extractGeminiUsageMetadata(geminiResponse) + + // Update the response with speech-specific data + bifrostResponse.Object = "audio.speech" + bifrostResponse.Speech = &schemas.BifrostSpeech{ + Audio: audioData, + Usage: &schemas.AudioLLMUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: totalTokens, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if speech stream is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationSpeechStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Validate input + if input == nil || input.Input == "" { + return nil, newBifrostOperationError("invalid speech input: no text provided", fmt.Errorf("empty text input"), providerName) + } + + // Prepare request body using shared function + requestBody := prepareGeminiGenerationRequest(input, params, []string{"AUDIO"}) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // Set headers for streaming + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", key.Value) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + + // Make the request + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, parseStreamGeminiError(providerName, resp) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + // Increase buffer size to handle large chunks (especially for audio data) + buf := make([]byte, 0, 64*1024) // 64KB buffer + scanner.Buffer(buf, 1024*1024) // Allow up to 1MB tokens + chunkIndex := -1 + usage := &schemas.AudioLLMUsage{} + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines + if line == "" { + continue + } + + var jsonData string + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData = strings.TrimPrefix(line, "data: ") + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // Process chunk using shared function + geminiResponse, err := processGeminiStreamChunk(jsonData) + if err != nil { + if strings.Contains(err.Error(), "gemini api error") { + // Handle API error + bifrostErr := &schemas.BifrostError{ + Type: Ptr("gemini_api_error"), + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + } + provider.logger.Warn(fmt.Sprintf("Failed to process chunk: %v", err)) + continue + } + + // Extract audio data from Gemini response for regular chunks + var audioChunk []byte + if len(geminiResponse.Candidates) > 0 { + candidate := geminiResponse.Candidates[0] + if candidate.Content != nil && len(candidate.Content.Parts) > 0 { + var buf []byte + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && part.InlineData.Data != nil { + buf = append(buf, part.InlineData.Data...) + } + } + if len(buf) > 0 { + audioChunk = buf + } + } + } + + // Check if this is the final chunk (has finishReason) + if len(geminiResponse.Candidates) > 0 && (geminiResponse.Candidates[0].FinishReason != "" || geminiResponse.UsageMetadata != nil) { + // Extract usage metadata using shared function + inputTokens, outputTokens, totalTokens := extractGeminiUsageMetadata(geminiResponse) + usage.InputTokens = inputTokens + usage.OutputTokens = outputTokens + usage.TotalTokens = totalTokens + } + + // Only send response if we have actual audio content + if len(audioChunk) > 0 { + chunkIndex++ + + // Create Bifrost speech response for streaming + response := &schemas.BifrostResponse{ + Object: "audio.speech.chunk", + Model: model, + Speech: &schemas.BifrostSpeech{ + Audio: audioChunk, + BifrostSpeechStreamResponse: &schemas.BifrostSpeechStreamResponse{ + Type: "audio.speech.chunk", + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + // Process response through post-hooks and send to channel + processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger) + } + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + } else { + response := &schemas.BifrostResponse{ + Object: "audio.speech.chunk", + Speech: &schemas.BifrostSpeech{ + Usage: usage, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex + 1, + }, + } + + if params != nil { + response.ExtraFields.Params = *params + } + handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) + } + }() + + return responseChan, nil +} + +func (provider *GeminiProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if transcription is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationTranscription); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Check file size limit (Gemini has a 20MB limit for inline data) + const maxFileSize = 20 * 1024 * 1024 // 20MB + if len(input.File) > maxFileSize { + return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(input.File)), providerName) + } + + if input.Prompt == nil { + input.Prompt = Ptr("Generate a transcript of the speech.") + } + + // Prepare request body using shared function + requestBody := prepareGeminiGenerationRequest(input, params, nil) + + // Use common request function + bifrostResponse, geminiResponse, bifrostErr := provider.completeRequest(ctx, model, key, requestBody, ":generateContent", params) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Extract text from Gemini response + var transcriptText string + if len(geminiResponse.Candidates) > 0 && geminiResponse.Candidates[0].Content != nil { + for _, p := range geminiResponse.Candidates[0].Content.Parts { + if p.Text != "" { + transcriptText += p.Text + } + } + } + + // If no transcript text was extracted, return an error + if transcriptText == "" { + return nil, newBifrostOperationError("failed to extract transcript from Gemini response", fmt.Errorf("no transcript text found"), providerName) + } + + // Extract usage metadata using shared function + inputTokens, outputTokens, totalTokens := extractGeminiUsageMetadata(geminiResponse) + + // Update the response with transcription-specific data + bifrostResponse.Object = "audio.transcription" + bifrostResponse.Transcribe = &schemas.BifrostTranscribe{ + Text: transcriptText, + Usage: &schemas.TranscriptionUsage{ + Type: "tokens", + InputTokens: &inputTokens, + OutputTokens: &outputTokens, + TotalTokens: &totalTokens, + }, + BifrostTranscribeNonStreamResponse: &schemas.BifrostTranscribeNonStreamResponse{ + Task: Ptr("transcribe"), + Language: input.Language, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if transcription stream is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationTranscriptionStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Check file size limit (Gemini has a 20MB limit for inline data) + if input.File != nil { + const maxFileSize = 20 * 1024 * 1024 // 20MB + if len(input.File) > maxFileSize { + return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(input.File)), providerName) + } + } + + if input.Prompt == nil { + input.Prompt = Ptr("Generate a transcript of the speech.") + } + + // Prepare request body using shared function + requestBody := prepareGeminiGenerationRequest(input, params, nil) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // Set headers for streaming + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", key.Value) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + + // Make the request + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, parseStreamGeminiError(providerName, resp) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + chunkIndex := -1 + usage := &schemas.TranscriptionUsage{} + + var fullTranscriptionText string + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines + if line == "" { + continue + } + var jsonData string + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData = strings.TrimPrefix(line, "data: ") + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // First, check if this is an error response + var errorCheck map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) + continue + } + + // Handle error responses + if _, hasError := errorCheck["error"]; hasError { + bifrostErr := &schemas.BifrostError{ + Type: Ptr("gemini_api_error"), + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Gemini API error: %v", errorCheck["error"]), + Error: fmt.Errorf("stream error: %v", errorCheck["error"]), + }, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + } + + // Parse Gemini streaming response + var geminiResponse GenerateContentResponse + if err := sonic.Unmarshal([]byte(jsonData), &geminiResponse); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse Gemini stream response: %v", err)) + continue + } + + // Extract text from Gemini response for regular chunks + var deltaText string + if len(geminiResponse.Candidates) > 0 && geminiResponse.Candidates[0].Content != nil { + if len(geminiResponse.Candidates[0].Content.Parts) > 0 { + var sb strings.Builder + for _, p := range geminiResponse.Candidates[0].Content.Parts { + if p.Text != "" { + sb.WriteString(p.Text) + } + } + if sb.Len() > 0 { + deltaText = sb.String() + fullTranscriptionText += deltaText + } + } + } + + // Check if this is the final chunk (has finishReason) + if len(geminiResponse.Candidates) > 0 && (geminiResponse.Candidates[0].FinishReason != "" || geminiResponse.UsageMetadata != nil) { + // Extract usage metadata from Gemini response + inputTokens, outputTokens, totalTokens := extractGeminiUsageMetadata(&geminiResponse) + usage.InputTokens = Ptr(inputTokens) + usage.OutputTokens = Ptr(outputTokens) + usage.TotalTokens = Ptr(totalTokens) + } + + // Only send response if we have actual text content + if deltaText != "" { + chunkIndex++ + + // Create Bifrost transcription response for streaming + response := &schemas.BifrostResponse{ + Object: "audio.transcription.chunk", + Transcribe: &schemas.BifrostTranscribe{ + BifrostTranscribeStreamResponse: &schemas.BifrostTranscribeStreamResponse{ + Type: Ptr("transcript.text.delta"), + Delta: &deltaText, // Delta text for this chunk + }, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex, + }, + } + + // Process response through post-hooks and send to channel + processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger) + } + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + } else { + response := &schemas.BifrostResponse{ + Object: "audio.transcription.chunk", + Transcribe: &schemas.BifrostTranscribe{ + Text: fullTranscriptionText, + Usage: &schemas.TranscriptionUsage{ + Type: "tokens", + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: chunkIndex + 1, + }, + } + + if params != nil { + response.ExtraFields.Params = *params + } + handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) + } + }() + + return responseChan, nil +} + +// prepareGeminiGenerationRequest prepares the common request structure for Gemini API calls +func prepareGeminiGenerationRequest(input interface{}, params *schemas.ModelParameters, responseModalities []string) map[string]interface{} { + requestBody := map[string]interface{}{ + "generationConfig": map[string]interface{}{}, + } + + // Add response modalities if specified + if len(responseModalities) > 0 { + requestBody["generationConfig"].(map[string]interface{})["responseModalities"] = responseModalities + } + + // Map Bifrost parameters to Gemini generationConfig + if params != nil { + generationConfig := requestBody["generationConfig"].(map[string]interface{}) + + // Map standard parameters to Gemini generationConfig + if params.StopSequences != nil { + generationConfig["stopSequences"] = *params.StopSequences + } + if params.MaxTokens != nil { + generationConfig["maxOutputTokens"] = *params.MaxTokens + } + if params.Temperature != nil { + generationConfig["temperature"] = *params.Temperature + } + if params.TopP != nil { + generationConfig["topP"] = *params.TopP + } + if params.TopK != nil { + generationConfig["topK"] = *params.TopK + } + if params.PresencePenalty != nil { + generationConfig["presencePenalty"] = *params.PresencePenalty + } + if params.FrequencyPenalty != nil { + generationConfig["frequencyPenalty"] = *params.FrequencyPenalty + } + + // Handle tool-related parameters + if params.Tools != nil && len(*params.Tools) > 0 { + // Transform Bifrost tools to Gemini format + var geminiTools []map[string]interface{} + for _, tool := range *params.Tools { + if tool.Type == "function" { + geminiTool := map[string]interface{}{ + "functionDeclarations": []map[string]interface{}{ + { + "name": tool.Function.Name, + "description": tool.Function.Description, + "parameters": tool.Function.Parameters, + }, + }, + } + geminiTools = append(geminiTools, geminiTool) + } + } + + if len(geminiTools) > 0 { + requestBody["tools"] = geminiTools + + // Add toolConfig for Gemini + toolConfig := map[string]interface{}{} + + // Handle tool choice + if params.ToolChoice != nil { + functionCallingConfig := map[string]interface{}{} + + if params.ToolChoice.ToolChoiceStr != nil { + // Map string values to Gemini's enum values + switch *params.ToolChoice.ToolChoiceStr { + case "none": + functionCallingConfig["mode"] = "NONE" + case "auto": + functionCallingConfig["mode"] = "AUTO" + case "any": + functionCallingConfig["mode"] = "ANY" + case "required": + functionCallingConfig["mode"] = "ANY" + default: + functionCallingConfig["mode"] = "AUTO" + } + } else if params.ToolChoice.ToolChoiceStruct != nil { + switch params.ToolChoice.ToolChoiceStruct.Type { + case schemas.ToolChoiceTypeNone: + functionCallingConfig["mode"] = "NONE" + case schemas.ToolChoiceTypeAuto: + functionCallingConfig["mode"] = "AUTO" + case schemas.ToolChoiceTypeRequired: + functionCallingConfig["mode"] = "ANY" + case schemas.ToolChoiceTypeFunction: + functionCallingConfig["mode"] = "ANY" + default: + functionCallingConfig["mode"] = "AUTO" + } + + // Handle specific function selection if provided + if params.ToolChoice.ToolChoiceStruct.Function.Name != "" { + functionCallingConfig["allowedFunctionNames"] = []string{params.ToolChoice.ToolChoiceStruct.Function.Name} + } + } + + // Only add functionCallingConfig if it has content + if len(functionCallingConfig) > 0 { + toolConfig["functionCallingConfig"] = functionCallingConfig + } + } + + // Only add toolConfig if it has content + if len(toolConfig) > 0 { + requestBody["toolConfig"] = toolConfig + } + } + } + + // Add any extra parameters that might be Gemini-specific + if params.ExtraParams != nil { + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + } + + // Add contents based on input type + switch v := input.(type) { + case *schemas.SpeechInput: + // Speech/TTS request + requestBody["contents"] = []map[string]interface{}{ + { + "parts": []map[string]interface{}{ + {"text": v.Input}, + }, + }, + } + addSpeechConfig(requestBody, v.VoiceConfig) + case *schemas.TranscriptionInput: + // Transcription request + parts := []map[string]interface{}{ + {"text": v.Prompt}, + } + + if len(v.File) > 0 { + if v.Format == nil { + v.Format = Ptr(detectAudioMimeType(v.File)) + } + parts = append(parts, map[string]interface{}{ + "inlineData": map[string]interface{}{ + "mimeType": *v.Format, + "data": v.File, + }, + }) + } + + requestBody["contents"] = []map[string]interface{}{ + {"parts": parts}, + } + case []schemas.BifrostMessage: + // Chat completion request + formattedMessages, _ := prepareOpenAIChatRequest(v, params) + requestBody["contents"] = formattedMessages + } + + return requestBody +} + +// addSpeechConfig adds speech configuration to the request body +func addSpeechConfig(requestBody map[string]interface{}, voiceConfig schemas.SpeechVoiceInput) { + speechConfig := map[string]interface{}{} + + // Handle single voice configuration + if voiceConfig.Voice != nil { + speechConfig["voiceConfig"] = map[string]interface{}{ + "prebuiltVoiceConfig": map[string]interface{}{ + "voiceName": *voiceConfig.Voice, + }, + } + } + + // Handle multi-speaker voice configuration + if len(voiceConfig.MultiVoiceConfig) > 0 { + var speakerVoiceConfigs []map[string]interface{} + for _, vc := range voiceConfig.MultiVoiceConfig { + speakerVoiceConfigs = append(speakerVoiceConfigs, map[string]interface{}{ + "speaker": vc.Speaker, + "voiceConfig": map[string]interface{}{ + "prebuiltVoiceConfig": map[string]interface{}{ + "voiceName": vc.Voice, + }, + }, + }) + } + + speechConfig["multiSpeakerVoiceConfig"] = map[string]interface{}{ + "speakerVoiceConfigs": speakerVoiceConfigs, + } + } + + // Add speech config to generation config if not empty + if len(speechConfig) > 0 { + requestBody["generationConfig"].(map[string]interface{})["speechConfig"] = speechConfig + } +} + +// processGeminiStreamChunk processes a single chunk from Gemini streaming response +func processGeminiStreamChunk(jsonData string) (*GenerateContentResponse, error) { + // First, check if this is an error response + var errorCheck map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + return nil, fmt.Errorf("failed to parse stream data as JSON: %v", err) + } + + // Handle error responses + if _, hasError := errorCheck["error"]; hasError { + return nil, fmt.Errorf("gemini api error: %v", errorCheck["error"]) + } + + // Parse Gemini streaming response + var geminiResponse GenerateContentResponse + if err := sonic.Unmarshal([]byte(jsonData), &geminiResponse); err != nil { + return nil, fmt.Errorf("failed to parse Gemini stream response: %v", err) + } + + return &geminiResponse, nil +} + +// extractGeminiUsageMetadata extracts usage metadata (as ints) from Gemini response +func extractGeminiUsageMetadata(geminiResponse *GenerateContentResponse) (int, int, int) { + var inputTokens, outputTokens, totalTokens int + if geminiResponse.UsageMetadata != nil { + usageMetadata := geminiResponse.UsageMetadata + inputTokens = int(usageMetadata.PromptTokenCount) + outputTokens = int(usageMetadata.CandidatesTokenCount) + totalTokens = int(usageMetadata.TotalTokenCount) + } + return inputTokens, outputTokens, totalTokens +} + +// completeRequest handles the common HTTP request pattern for Gemini API calls +func (provider *GeminiProvider) completeRequest(ctx context.Context, model string, key schemas.Key, requestBody map[string]interface{}, endpoint string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *GenerateContentResponse, *schemas.BifrostError) { + providerName := provider.GetProviderKey() + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + // Use Gemini's generateContent endpoint + req.SetRequestURI(provider.networkConfig.BaseURL + "/models/" + model + endpoint) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("x-goog-api-key", key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, nil, parseGeminiError(providerName, resp) + } + + responseBody := resp.Body() + + // Parse Gemini's response + var geminiResponse GenerateContentResponse + if err := sonic.Unmarshal(responseBody, &geminiResponse); err != nil { + return nil, nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + } + + // Create base response + bifrostResponse := &schemas.BifrostResponse{ + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + var rawResponse interface{} + if err := sonic.Unmarshal(responseBody, &rawResponse); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + } + + return bifrostResponse, &geminiResponse, nil +} + +// parseStreamGeminiError parses Gemini streaming error responses +func parseStreamGeminiError(providerName schemas.ModelProvider, resp *http.Response) *schemas.BifrostError { + body, err := io.ReadAll(resp.Body) + if err != nil { + return newBifrostOperationError("failed to read error response body", err, providerName) + } + + // Try to parse as JSON first + var errorResp map[string]interface{} + if err := sonic.Unmarshal(body, &errorResp); err == nil { + // Successfully parsed as JSON + return newBifrostOperationError(fmt.Sprintf("Gemini streaming error: %v", errorResp), fmt.Errorf("HTTP %d", resp.StatusCode), providerName) + } + + // If JSON parsing fails, treat as plain text + bodyStr := string(body) + if bodyStr == "" { + bodyStr = "empty response body" + } + + return newBifrostOperationError(fmt.Sprintf("Gemini streaming error (HTTP %d): %s", resp.StatusCode, bodyStr), fmt.Errorf("HTTP %d", resp.StatusCode), providerName) +} + +// parseGeminiError parses Gemini error responses +func parseGeminiError(providerName schemas.ModelProvider, resp *fasthttp.Response) *schemas.BifrostError { + var errorResp map[string]interface{} + body := resp.Body() + + if err := sonic.Unmarshal(body, &errorResp); err != nil { + return newBifrostOperationError("failed to parse error response", err, providerName) + } + + return newBifrostOperationError(fmt.Sprintf("Gemini error: %v", errorResp), fmt.Errorf("HTTP %d", resp.StatusCode()), providerName) +} diff --git a/core/providers/groq.go b/core/providers/groq.go new file mode 100644 index 000000000..93d34e11b --- /dev/null +++ b/core/providers/groq.go @@ -0,0 +1,226 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Groq provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// // groqResponsePool provides a pool for Groq response objects. +// var groqResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireGroqResponse gets a Groq response from the pool and resets it. +// func acquireGroqResponse() *schemas.BifrostResponse { +// resp := groqResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseGroqResponse returns a Groq response to the pool. +// func releaseGroqResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// groqResponsePool.Put(resp) +// } +// } + +// GroqProvider implements the Provider interface for Groq's API. +type GroqProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewGroqProvider creates a new Groq provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewGroqProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*GroqProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // groqResponsePool.Put(&schemas.BifrostResponse{}) + // } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.groq.com/openai" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &GroqProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +// GetProviderKey returns the provider identifier for Groq. +func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Groq +} + +// TextCompletion is not supported by the Groq provider. +func (provider *GroqProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "groq") +} + +// ChatCompletion performs a chat completion request to the Groq API. +func (provider *GroqProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Groq) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from groq provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Groq error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireGroqResponse() + // defer releaseGroqResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response.ExtraFields.Provider = schemas.Groq + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding is not supported by the Groq provider. +func (provider *GroqProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "groq") +} + +// ChatCompletionStream performs a streaming chat completion request to the Groq API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Groq's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Groq headers + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + headers["Authorization"] = "Bearer " + key.Value + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Groq, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *GroqProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "groq") +} + +func (provider *GroqProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "groq") +} + +func (provider *GroqProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "groq") +} + +func (provider *GroqProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "groq") +} diff --git a/core/providers/mistral.go b/core/providers/mistral.go new file mode 100644 index 000000000..cdca1a863 --- /dev/null +++ b/core/providers/mistral.go @@ -0,0 +1,311 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Mistral provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// // mistralResponsePool provides a pool for Mistral response objects. +// var mistralResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireMistralResponse gets a Mistral response from the pool and resets it. +// func acquireMistralResponse() *schemas.BifrostResponse { +// resp := mistralResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseMistralResponse returns a Mistral response to the pool. +// func releaseMistralResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// mistralResponsePool.Put(resp) +// } +// } + +// MistralProvider implements the Provider interface for Mistral's API. +type MistralProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewMistralProvider creates a new Mistral provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) *MistralProvider { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // mistralResponsePool.Put(&schemas.BifrostResponse{}) + // } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.mistral.ai" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &MistralProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + } +} + +// GetProviderKey returns the provider identifier for Mistral. +func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Mistral +} + +// TextCompletion is not supported by the Mistral provider. +func (provider *MistralProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "mistral") +} + +// ChatCompletion performs a chat completion request to the Mistral API. +func (provider *MistralProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from mistral provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Mistral error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireMistralResponse() + // defer releaseMistralResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.Mistral + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding generates embeddings for the given input text(s) using the Mistral API. +// Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). +func (provider *MistralProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Prepare request body with base parameters + requestBody := map[string]interface{}{ + "model": model, + "input": input, + } + + // Merge any additional parameters + if params != nil { + // Validate encoding format - Mistral API supports multiple formats, but our provider only implements float + if params.EncodingFormat != nil { + if *params.EncodingFormat != "float" { + return nil, newConfigurationError(fmt.Sprintf("Mistral provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), schemas.Mistral) + } + // Map to Mistral's parameter name + requestBody["output_dtype"] = *params.EncodingFormat + } + + // Map dimensions to Mistral's parameter name + if params.Dimensions != nil { + requestBody["output_dimension"] = *params.Dimensions + } + + // Merge any extra parameters + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from mistral embedding provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Mistral embedding error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireMistralResponse() + response := &schemas.BifrostResponse{} + // defer releaseMistralResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.Mistral + + if params != nil { + response.ExtraFields.Params = *params + } + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + +// ChatCompletionStream performs a streaming chat completion request to the Mistral API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Mistral's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Mistral headers + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Mistral, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *MistralProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "mistral") +} + +func (provider *MistralProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "mistral") +} + +func (provider *MistralProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "mistral") +} + +func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "mistral") +} diff --git a/core/providers/ollama.go b/core/providers/ollama.go new file mode 100644 index 000000000..ff0a5b374 --- /dev/null +++ b/core/providers/ollama.go @@ -0,0 +1,231 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Ollama provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// // ollamaResponsePool provides a pool for Ollama response objects. +// var ollamaResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireOllamaResponse gets a Ollama response from the pool and resets it. +// func acquireOllamaResponse() *schemas.BifrostResponse { +// resp := ollamaResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseOllamaResponse returns a Ollama response to the pool. +// func releaseOllamaResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// ollamaResponsePool.Put(resp) +// } +// } + +// OllamaProvider implements the Provider interface for Ollama's API. +type OllamaProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewOllamaProvider creates a new Ollama provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*OllamaProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // ollamaResponsePool.Put(&schemas.BifrostResponse{}) + // } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + // BaseURL is required for Ollama + if config.NetworkConfig.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for ollama provider") + } + + return &OllamaProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +// GetProviderKey returns the provider identifier for Ollama. +func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Ollama +} + +// TextCompletion is not supported by the Ollama provider. +func (provider *OllamaProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "ollama") +} + +// ChatCompletion performs a chat completion request to the Ollama API. +func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Ollama) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) + } + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from ollama provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Ollama error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireOllamaResponse() + // defer releaseOllamaResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.Ollama + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding is not supported by the Ollama provider. +func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "ollama") +} + +// ChatCompletionStream performs a streaming chat completion request to the Ollama API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Ollama's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Ollama headers (Ollama typically doesn't require authorization, but we include it if provided) + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Only add Authorization header if key is provided (Ollama can run without auth) + if key.Value != "" { + headers["Authorization"] = "Bearer " + key.Value + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Ollama, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *OllamaProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "ollama") +} + +func (provider *OllamaProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "ollama") +} + +func (provider *OllamaProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "ollama") +} + +func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "ollama") +} diff --git a/core/providers/openai.go b/core/providers/openai.go index ff96a69b7..9bae45802 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -3,171 +3,579 @@ package providers import ( - "sync" + "bufio" + "bytes" + "context" + "fmt" + "io" + "mime/multipart" + "net/http" + "strings" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// OpenAIResponse represents the response structure from the OpenAI API. -// It includes completion choices, model information, and usage statistics. -type OpenAIResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Object string `json:"object"` // Type of completion (text.completion or chat.completion) - Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices - Model string `json:"model"` // Model used for the completion - Created int `json:"created"` // Unix timestamp of completion creation - ServiceTier *string `json:"service_tier"` // Service tier used for the request - SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request - Usage schemas.LLMUsage `json:"usage"` // Token usage statistics -} - -// OpenAIError represents the error response structure from the OpenAI API. -// It includes detailed error information and event tracking. -type OpenAIError struct { - EventID string `json:"event_id"` // Unique identifier for the error event - Type string `json:"type"` // Type of error - Error struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking - } `json:"error"` -} +// // openAIResponsePool provides a pool for OpenAI response objects. +// var openAIResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } -// openAIResponsePool provides a pool for OpenAI response objects. -var openAIResponsePool = sync.Pool{ - New: func() interface{} { - return &OpenAIResponse{} - }, -} - -// acquireOpenAIResponse gets an OpenAI response from the pool and resets it. -func acquireOpenAIResponse() *OpenAIResponse { - resp := openAIResponsePool.Get().(*OpenAIResponse) - *resp = OpenAIResponse{} // Reset the struct - return resp -} +// // acquireOpenAIResponse gets an OpenAI response from the pool and resets it. +// func acquireOpenAIResponse() *schemas.BifrostResponse { +// resp := openAIResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } -// releaseOpenAIResponse returns an OpenAI response to the pool. -func releaseOpenAIResponse(resp *OpenAIResponse) { - if resp != nil { - openAIResponsePool.Put(resp) - } -} +// // releaseOpenAIResponse returns an OpenAI response to the pool. +// func releaseOpenAIResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// openAIResponsePool.Put(resp) +// } +// } -// OpenAIProvider implements the Provider interface for OpenAI's API. +// OpenAIProvider implements the Provider interface for OpenAI's GPT API. type OpenAIProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse + customProviderConfig *schemas.CustomProviderConfig // Custom provider config } // NewOpenAIProvider creates a new OpenAI provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *OpenAIProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, } - // Pre-warm response pools - for range config.ConcurrencyAndBufferSize.Concurrency { - openAIResponsePool.Put(&OpenAIResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } + // // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // openAIResponsePool.Put(&schemas.BifrostResponse{}) + // } + // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.openai.com" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + return &OpenAIProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + customProviderConfig: config.CustomProviderConfig, } } // GetProviderKey returns the provider identifier for OpenAI. func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { - return schemas.OpenAI + return getProviderName(schemas.OpenAI, provider.customProviderConfig) } // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. -func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by openai provider", - }, - } +func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "openai") } // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if chat completion is allowed for this provider + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireOpenAIResponse() + // defer releaseOpenAIResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + response.ExtraFields.Provider = providerName + + return response, nil +} + +// prepareOpenAIChatRequest formats messages for the OpenAI API. +// It handles both text and image content in messages. +// Returns a slice of formatted messages and any additional parameters. +func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { // Format messages for OpenAI API var formattedMessages []map[string]interface{} for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} - - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, - }) + if msg.Role == schemas.ModelChatMessageRoleAssistant { + assistantMessage := map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, } - - imageContent := map[string]interface{}{ - "type": "image_url", - "image_url": map[string]interface{}{ - "url": msg.ImageContent.URL, - }, + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + assistantMessage["tool_calls"] = *msg.AssistantMessage.ToolCalls + } + formattedMessages = append(formattedMessages, assistantMessage) + } else { + message := map[string]interface{}{ + "role": msg.Role, } - if msg.ImageContent.Detail != nil { - imageContent["image_url"].(map[string]interface{})["detail"] = msg.ImageContent.Detail + if msg.Content.ContentStr != nil { + message["content"] = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + contentBlocks := *msg.Content.ContentBlocks + for i := range contentBlocks { + if contentBlocks[i].Type == schemas.ContentBlockTypeImage && contentBlocks[i].ImageURL != nil { + sanitizedURL, _ := SanitizeImageURL(contentBlocks[i].ImageURL.URL) + contentBlocks[i].ImageURL.URL = sanitizedURL + } + } + + message["content"] = contentBlocks } - content = append(content, imageContent) + if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + message["tool_call_id"] = *msg.ToolMessage.ToolCallID + } - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) + formattedMessages = append(formattedMessages, message) } } preparedParams := prepareParams(params) + return formattedMessages, preparedParams +} + +// Embedding generates embeddings for the given input text(s). +// The input can be either a single string or a slice of strings for batch embedding. +// Returns a BifrostResponse containing the embedding(s) and any error that occurred. +func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if embedding is allowed for this provider + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + requestBody := prepareOpenAIEmbeddingRequest(input, params) + requestBody["model"] = model + + // Use the shared embedding request handler + return handleOpenAIEmbeddingRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/embeddings", + requestBody, + key, + params, + provider.networkConfig.ExtraHeaders, + providerName, + provider.sendBackRawResponse, + provider.logger, + ) +} + +func prepareOpenAIEmbeddingRequest(input *schemas.EmbeddingInput, params *schemas.ModelParameters) map[string]interface{} { + requestBody := map[string]interface{}{ + "input": input, + } + + // Merge any additional parameters + if params != nil { + // Map standard parameters + if params.EncodingFormat != nil { + requestBody["encoding_format"] = *params.EncodingFormat + } + if params.Dimensions != nil { + requestBody["dimensions"] = *params.Dimensions + } + if params.User != nil { + requestBody["user"] = *params.User + } + + // Merge any extra parameters + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + return requestBody +} + +func handleOpenAIEmbeddingRequest(ctx context.Context, client *fasthttp.Client, url string, requestBody map[string]interface{}, key schemas.Key, params *schemas.ModelParameters, extraHeaders map[string]string, providerName schemas.ModelProvider, sendBackRawResponse bool, logger schemas.Logger) (*schemas.BifrostResponse, *schemas.BifrostError) { + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, extraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + responseBody := resp.Body() + + // Pre-allocate response structs + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = providerName + + if params != nil { + response.ExtraFields.Params = *params + } + + if sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + +// ChatCompletionStream handles streaming for OpenAI chat completions. +// It formats messages, prepares request body, and uses shared streaming logic. +// Returns a channel for streaming responses and any error that occurred. +func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if chat completion stream is allowed for this provider + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + return nil, err + } + + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := mergeConfig(map[string]interface{}{ "model": model, "messages": formattedMessages, + "stream": true, + "stream_options": map[string]interface{}{ + "include_usage": true, + }, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + // Prepare OpenAI headers + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + providerName := provider.GetProviderKey() + + // Use shared streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + providerName, + params, + postHookRunner, + provider.logger, + ) +} + +// performOpenAICompatibleStreaming handles streaming for OpenAI-compatible APIs (OpenAI, Azure). +// This shared function reduces code duplication between providers that use the same SSE format. +func handleOpenAIStreaming( + ctx context.Context, + httpClient *http.Client, + url string, + requestBody map[string]interface{}, + headers map[string]string, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + params *schemas.ModelParameters, + postHookRunner schemas.PostHookRunner, + logger schemas.Logger, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + jsonBody, err := sonic.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, extraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) + } + + // Make the request + resp, err := httpClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + return nil, parseStreamOpenAIError(resp) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + chunkIndex := -1 + usage := &schemas.LLMUsage{} + + var finishReason *string + var id string + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Check for end of stream + if line == "data: [DONE]" { + break + } + + var jsonData string + + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData = strings.TrimPrefix(line, "data: ") + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // First, check if this is an error response + var errorCheck map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) + continue + } + + // Handle error responses + if _, hasError := errorCheck["error"]; hasError { + bifrostErr, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) + continue + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + + // Parse into bifrost response + var response schemas.BifrostResponse + if err := sonic.Unmarshal([]byte(jsonData), &response); err != nil { + logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) + continue + } + + // Handle usage-only chunks (when stream_options include_usage is true) + if len(response.Choices) == 0 && response.Usage != nil { + // Collect usage information and send at the end of the stream + usage = response.Usage + response.Usage = nil + } + + // Skip empty responses or responses without choices + if len(response.Choices) == 0 { + continue + } + + // Handle finish reason, usually in the final chunk + choice := response.Choices[0] + if choice.FinishReason != nil && *choice.FinishReason != "" { + // Collect finish reason and send at the end of the stream + finishReason = choice.FinishReason + response.Choices[0].FinishReason = nil + } + + if response.ID != "" && id == "" { + id = response.ID + } + + // Handle regular content chunks + if choice.BifrostStreamResponseChoice != nil && (choice.BifrostStreamResponseChoice.Delta.Content != nil || len(choice.BifrostStreamResponseChoice.Delta.ToolCalls) > 0) { + chunkIndex++ + + response.ExtraFields.Provider = providerName + response.ExtraFields.ChunkIndex = chunkIndex + + processAndSendResponse(ctx, postHookRunner, &response, responseChan, logger) + } } + + // Handle scanner errors first + if err := scanner.Err(); err != nil { + logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, logger) + } else { + response := createBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, params, providerName) + handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger) + } + }() + + return responseChan, nil +} + +// Speech handles non-streaming speech synthesis requests. +// It formats the request body, makes the API call, and returns the response. +// Returns the response and any error that occurred. +func (provider *OpenAIProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationSpeech); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + responseFormat := input.ResponseFormat + if responseFormat == "" { + responseFormat = "mp3" + } + + requestBody := map[string]interface{}{ + "input": input.Input, + "model": model, + "voice": input.VoiceConfig.Voice, + "instructions": input.Instructions, + "response_format": responseFormat, + } + + if params != nil { + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } // Create request @@ -176,67 +584,607 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - req.SetRequestURI("https://api.openai.com/v1/chat/completions") + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/audio/speech") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) + req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - var errorResp OpenAIError + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + // Get the binary audio data from the response body + audioData := resp.Body() + + // Create final response with the audio data + // Note: For speech synthesis, we return the binary audio data in the raw response + // The audio data is typically in MP3, WAV, or other audio formats as specified by response_format + bifrostResponse := &schemas.BifrostResponse{ + Object: "audio.speech", + Model: model, + Speech: &schemas.BifrostSpeech{ + Audio: audioData, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// SpeechStream handles streaming for speech synthesis. +// It formats the request body, creates HTTP request, and uses shared streaming logic. +// Returns a channel for streaming responses and any error that occurred. +func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationSpeechStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + responseFormat := input.ResponseFormat + if responseFormat == "" { + responseFormat = "mp3" + } + + requestBody := map[string]interface{}{ + "input": input.Input, + "model": model, + "voice": input.VoiceConfig.Voice, + "instructions": input.Instructions, + "response_format": responseFormat, + "stream_format": "sse", + } + + if params != nil { + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Prepare OpenAI headers + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/speech", bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) + } + + // Make the request + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } - bifrostErr := handleProviderAPIError(resp, &errorResp) + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + return nil, parseStreamOpenAIError(resp) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + chunkIndex := -1 + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } - bifrostErr.EventID = &errorResp.EventID - bifrostErr.Error.Type = &errorResp.Error.Type - bifrostErr.Error.Code = &errorResp.Error.Code - bifrostErr.Error.Message = errorResp.Error.Message - bifrostErr.Error.Param = errorResp.Error.Param - bifrostErr.Error.EventID = &errorResp.Error.EventID + // Check for end of stream + if line == "data: [DONE]" { + break + } + + var jsonData string + + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData = strings.TrimPrefix(line, "data: ") + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // First, check if this is an error response + var errorCheck map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) + continue + } + + // Handle error responses + if _, hasError := errorCheck["error"]; hasError { + bifrostErr, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) + continue + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + } + + // Parse into bifrost response + var response schemas.BifrostResponse + + var speechResponse schemas.BifrostSpeech + if err := sonic.Unmarshal([]byte(jsonData), &speechResponse); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) + continue + } + chunkIndex++ + + response.Speech = &speechResponse + response.Object = "audio.speech.chunk" + response.Model = model + response.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: providerName, + } + + response.ExtraFields.ChunkIndex = chunkIndex + + if speechResponse.Usage != nil { + if params != nil { + response.ExtraFields.Params = *params + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) + return + } + + processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + } + }() + + return responseChan, nil +} + +// Transcription handles non-streaming transcription requests. +// It creates a multipart form, adds fields, makes the API call, and returns the response. +// Returns the response and any error that occurred. +func (provider *OpenAIProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationTranscription); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Create multipart form + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + if bifrostErr := parseTranscriptionFormDataBody(writer, input, model, params, providerName); bifrostErr != nil { return nil, bifrostErr } - responseBody := resp.Body() + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) - // Pre-allocate response structs from pools - response := acquireOpenAIResponse() - defer releaseOpenAIResponse(response) + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - result := acquireBifrostResponse() - defer releaseBifrostResponse(result) + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/audio/transcriptions") + req.Header.SetMethod("POST") + req.Header.SetContentType(writer.FormDataContentType()) // This sets multipart/form-data with boundary + req.Header.Set("Authorization", "Bearer "+key.Value) - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + req.SetBody(body.Bytes()) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) if bifrostErr != nil { return nil, bifrostErr } - // Populate result from response - result.ID = response.ID - result.Choices = response.Choices - result.Object = response.Object - result.Usage = response.Usage - result.ServiceTier = response.ServiceTier - result.SystemFingerprint = response.SystemFingerprint - result.Model = response.Model - result.Created = response.Created - result.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - RawResponse: rawResponse, + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + responseBody := resp.Body() + + // Parse OpenAI's transcription response directly into BifrostTranscribe + transcribeResponse := &schemas.BifrostTranscribe{ + BifrostTranscribeNonStreamResponse: &schemas.BifrostTranscribeNonStreamResponse{}, + } + + if err := sonic.Unmarshal(responseBody, transcribeResponse); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + } + + // Parse raw response for RawResponse field + var rawResponse interface{} + if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderDecodeRaw, err, providerName) + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + Object: "audio.transcription", + Model: model, + Transcribe: transcribeResponse, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + }, + } + + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil + +} + +func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationTranscriptionStream); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Create multipart form + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + if err := writer.WriteField("stream", "true"); err != nil { + return nil, newBifrostOperationError("failed to write stream field", err, providerName) + } + + if bifrostErr := parseTranscriptionFormDataBody(writer, input, model, params, providerName); bifrostErr != nil { + return nil, bifrostErr + } + + // Prepare OpenAI headers + headers := map[string]string{ + "Content-Type": writer.FormDataContentType(), + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Create HTTP request for streaming + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/transcriptions", &body) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) + } + + // Make the request + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + return nil, parseStreamOpenAIError(resp) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + chunkIndex := -1 + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" { + continue + } + + // Check for end of stream + if line == "data: [DONE]" { + break + } + + var jsonData string + // Parse SSE data + if strings.HasPrefix(line, "data: ") { + jsonData = strings.TrimPrefix(line, "data: ") + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // First, check if this is an error response + var errorCheck map[string]interface{} + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) + continue + } + + // Handle error responses + if _, hasError := errorCheck["error"]; hasError { + bifrostErr, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) + continue + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + } + + var response schemas.BifrostResponse + + var transcriptionResponse schemas.BifrostTranscribe + if err := sonic.Unmarshal([]byte(jsonData), &transcriptionResponse); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) + continue + } + + chunkIndex++ + + response.Transcribe = &transcriptionResponse + response.Object = "audio.transcription.chunk" + response.Model = model + response.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: providerName, + } + + response.ExtraFields.ChunkIndex = chunkIndex + + if transcriptionResponse.Usage != nil { + if params != nil { + response.ExtraFields.Params = *params + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) + return + } + + processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + } + }() + + return responseChan, nil +} + +func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.TranscriptionInput, model string, params *schemas.ModelParameters, providerName schemas.ModelProvider) *schemas.BifrostError { + // Add file field + fileWriter, err := writer.CreateFormFile("file", "audio.mp3") // OpenAI requires a filename + if err != nil { + return newBifrostOperationError("failed to create form file", err, providerName) + } + if _, err := fileWriter.Write(input.File); err != nil { + return newBifrostOperationError("failed to write file data", err, providerName) + } + + // Add model field + if err := writer.WriteField("model", model); err != nil { + return newBifrostOperationError("failed to write model field", err, providerName) + } + + // Add optional fields + if input.Language != nil { + if err := writer.WriteField("language", *input.Language); err != nil { + return newBifrostOperationError("failed to write language field", err, providerName) + } + } + + if input.Prompt != nil { + if err := writer.WriteField("prompt", *input.Prompt); err != nil { + return newBifrostOperationError("failed to write prompt field", err, providerName) + } + } + + if input.ResponseFormat != nil { + if err := writer.WriteField("response_format", *input.ResponseFormat); err != nil { + return newBifrostOperationError("failed to write response_format field", err, providerName) + } + } + + // Note: Temperature and TimestampGranularities can be added via params.ExtraParams if needed + + // Add extra params if provided + if params != nil && params.ExtraParams != nil { + for key, value := range params.ExtraParams { + // Handle array parameters specially for OpenAI's form data format + switch v := value.(type) { + case []string: + // For arrays like timestamp_granularities[] or include[] + for _, item := range v { + if err := writer.WriteField(key+"[]", item); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write array param %s", key), err, providerName) + } + } + case []interface{}: + // Handle generic interface arrays + for _, item := range v { + if err := writer.WriteField(key+"[]", fmt.Sprintf("%v", item)); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write array param %s", key), err, providerName) + } + } + default: + // Handle non-array parameters normally + if err := writer.WriteField(key, fmt.Sprintf("%v", value)); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write extra param %s", key), err, providerName) + } + } + } + } + + // Close the multipart writer + if err := writer.Close(); err != nil { + return newBifrostOperationError("failed to close multipart writer", err, providerName) + } + + return nil +} + +func parseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { + var errorResp schemas.BifrostError + + bifrostErr := handleProviderAPIError(resp, &errorResp) + + if errorResp.EventID != nil { + bifrostErr.EventID = errorResp.EventID + } + bifrostErr.Error.Type = errorResp.Error.Type + bifrostErr.Error.Code = errorResp.Error.Code + bifrostErr.Error.Message = errorResp.Error.Message + bifrostErr.Error.Param = errorResp.Error.Param + if errorResp.Error.EventID != nil { + bifrostErr.Error.EventID = errorResp.Error.EventID + } + + return bifrostErr +} + +func parseStreamOpenAIError(resp *http.Response) *schemas.BifrostError { + var errorResp schemas.BifrostError + + statusCode := resp.StatusCode + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if err := sonic.Unmarshal(body, &errorResp); err != nil { + return &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &statusCode, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: schemas.ErrorField{}, + } + + if errorResp.EventID != nil { + bifrostErr.EventID = errorResp.EventID + } + bifrostErr.Error.Type = errorResp.Error.Type + bifrostErr.Error.Code = errorResp.Error.Code + bifrostErr.Error.Message = errorResp.Error.Message + bifrostErr.Error.Param = errorResp.Error.Param + if errorResp.Error.EventID != nil { + bifrostErr.Error.EventID = errorResp.Error.EventID + } + + return bifrostErr +} + +func parseOpenAIErrorForStreamDataLine(jsonData string) (*schemas.BifrostError, error) { + var openAIError schemas.BifrostError + if err := sonic.Unmarshal([]byte(jsonData), &openAIError); err != nil { + return nil, err + } + + // Send error through channel + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: openAIError.Error.Type, + Code: openAIError.Error.Code, + Message: openAIError.Error.Message, + Param: openAIError.Error.Param, + }, + } + + if openAIError.EventID != nil { + bifrostErr.EventID = openAIError.EventID + } + if openAIError.Error.EventID != nil { + bifrostErr.Error.EventID = openAIError.Error.EventID } - return result, nil + return bifrostErr, nil } diff --git a/core/providers/openrouter.go b/core/providers/openrouter.go new file mode 100644 index 000000000..0496e9b94 --- /dev/null +++ b/core/providers/openrouter.go @@ -0,0 +1,334 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the OpenRouter provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// OpenRouter response structures + +// OpenRouterTextResponse represents the response from OpenRouter text completion API +type OpenRouterTextResponse struct { + ID string `json:"id"` + Model string `json:"model"` + Created int `json:"created"` + SystemFingerprint *string `json:"system_fingerprint"` + Choices []OpenRouterTextChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage"` +} + +// OpenRouterTextChoice represents a choice in the OpenRouter text completion response +type OpenRouterTextChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` +} + +// OpenRouterProvider implements the Provider interface for OpenRouter's API. +type OpenRouterProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// openRouterTextCompletionResponsePool provides a pool for OpenRouter text completion response objects. +var openRouterTextCompletionResponsePool = sync.Pool{ + New: func() interface{} { + return &OpenRouterTextResponse{} + }, +} + +// acquireOpenRouterTextResponse gets an OpenRouter text completion response from the pool and resets it. +func acquireOpenRouterTextResponse() *OpenRouterTextResponse { + resp := openRouterTextCompletionResponsePool.Get().(*OpenRouterTextResponse) + *resp = OpenRouterTextResponse{} // Reset the struct + return resp +} + +// releaseOpenRouterTextResponse returns an OpenRouter text completion response to the pool. +func releaseOpenRouterTextResponse(resp *OpenRouterTextResponse) { + if resp != nil { + openRouterTextCompletionResponsePool.Put(resp) + } +} + +// NewOpenRouterProvider creates a new OpenRouter provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewOpenRouterProvider(config *schemas.ProviderConfig, logger schemas.Logger) *OpenRouterProvider { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { + openRouterTextCompletionResponsePool.Put(&OpenRouterTextResponse{}) + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://openrouter.ai/api" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &OpenRouterProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + } +} + +// GetProviderKey returns the provider identifier for OpenRouter. +func (provider *OpenRouterProvider) GetProviderKey() schemas.ModelProvider { + return schemas.OpenRouter +} + +// TextCompletion performs a text completion request to the OpenRouter API. +func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + preparedParams := prepareParams(params) + + // Merge additional parameters + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "prompt": text, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenRouter) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from openrouter provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("OpenRouter error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Create response object from pool + response := acquireOpenRouterTextResponse() + defer releaseOpenRouterTextResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + choices := make([]schemas.BifrostResponseChoice, 0, len(response.Choices)) + for i, ch := range response.Choices { + txt := ch.Text // local copy; safe after pool release + fr := ch.FinishReason // local copy; safe after pool release + choices = append(choices, schemas.BifrostResponseChoice{ + Index: i, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ContentStr: &txt}, + }, + }, + FinishReason: &fr, + }) + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenRouter, + }, + } + + // Set raw response if enabled + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// ChatCompletion performs a chat completion request to the OpenRouter API. +func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenRouter) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from openrouter provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("OpenRouter error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.OpenRouter + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// ChatCompletionStream performs a streaming chat completion request to the OpenRouter API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses OpenRouter's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *OpenRouterProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare OpenRouter headers + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + key.Value, + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.OpenRouter, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *OpenRouterProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "openrouter") +} + +func (provider *OpenRouterProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "openrouter") +} + +func (provider *OpenRouterProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "openrouter") +} + +func (provider *OpenRouterProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "openrouter") +} + +func (provider *OpenRouterProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "openrouter") +} diff --git a/core/providers/parasail.go b/core/providers/parasail.go new file mode 100644 index 000000000..97c70d714 --- /dev/null +++ b/core/providers/parasail.go @@ -0,0 +1,230 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Parasail provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// // parasailResponsePool provides a pool for Parasail response objects. +// var parasailResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireParasailResponse gets a Parasail response from the pool and resets it. +// func acquireParasailResponse() *schemas.BifrostResponse { +// resp := parasailResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseParasailResponse returns a Parasail response to the pool. +// func releaseParasailResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// parasailResponsePool.Put(resp) +// } +// } + +// ParasailProvider implements the Provider interface for Parasail's API. +type ParasailProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewParasailProvider creates a new Parasail provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewParasailProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*ParasailProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // parasailResponsePool.Put(&schemas.BifrostResponse{}) + // } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.parasail.io" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &ParasailProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +// GetProviderKey returns the provider identifier for Parasail. +func (provider *ParasailProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Parasail +} + +// TextCompletion is not supported by the Parasail provider. +func (provider *ParasailProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "parasail") +} + +// ChatCompletion performs a chat completion request to the Parasail API. +func (provider *ParasailProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Parasail) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from parasail provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Parasail error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireParasailResponse() + // defer releaseParasailResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response.ExtraFields.Provider = schemas.Parasail + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding is not supported by the Parasail provider. +func (provider *ParasailProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "parasail") +} + +// ChatCompletionStream performs a streaming chat completion request to the Parasail API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses Parasail's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *ParasailProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare Parasail headers + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + headers["Authorization"] = "Bearer " + key.Value + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Parasail, + params, + postHookRunner, + provider.logger, + ) +} + +// Speech is not supported by the Parasail provider. +func (provider *ParasailProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "parasail") +} + +// SpeechStream is not supported by the Parasail provider. +func (provider *ParasailProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "parasail") +} + +// Transcription is not supported by the Parasail provider. +func (provider *ParasailProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "parasail") +} + +// TranscriptionStream is not supported by the Parasail provider. +func (provider *ParasailProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "parasail") +} diff --git a/core/providers/sgl.go b/core/providers/sgl.go new file mode 100644 index 000000000..2a5f406ff --- /dev/null +++ b/core/providers/sgl.go @@ -0,0 +1,237 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the SGL provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// // sglResponsePool provides a pool for SGL response objects. +// var sglResponsePool = sync.Pool{ +// New: func() interface{} { +// return &schemas.BifrostResponse{} +// }, +// } + +// // acquireSGLResponse gets a SGL response from the pool and resets it. +// func acquireSGLResponse() *schemas.BifrostResponse { +// resp := sglResponsePool.Get().(*schemas.BifrostResponse) +// *resp = schemas.BifrostResponse{} // Reset the struct +// return resp +// } + +// // releaseSGLResponse returns a SGL response to the pool. +// func releaseSGLResponse(resp *schemas.BifrostResponse) { +// if resp != nil { +// sglResponsePool.Put(resp) +// } +// } + +// SGLProvider implements the Provider interface for SGL's API. +type SGLProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewSGLProvider creates a new SGL provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGLProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + // for range config.ConcurrencyAndBufferSize.Concurrency { + // sglResponsePool.Put(&schemas.BifrostResponse{}) + // } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + // BaseURL is required for SGLang + if config.NetworkConfig.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for sgl provider") + } + + return &SGLProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +// GetProviderKey returns the provider identifier for SGL. +func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { + return schemas.SGL +} + +// TextCompletion is not supported by the SGL provider. +func (provider *SGLProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "sgl") +} + +// ChatCompletion performs a chat completion request to the SGL API. +func (provider *SGLProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) + } + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from sgl provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("SGL error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireSGLResponse() + response := &schemas.BifrostResponse{} + // defer releaseSGLResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.SGL + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// Embedding is not supported by the SGL provider. +func (provider *SGLProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "sgl") +} + +// ChatCompletionStream performs a streaming chat completion request to the SGL API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses SGL's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided) + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Only add Authorization header if key is provided (SGL can run without auth) + if key.Value != "" { + headers["Authorization"] = "Bearer " + key.Value + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.SGL, + params, + postHookRunner, + provider.logger, + ) +} + +func (provider *SGLProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "sgl") +} + +func (provider *SGLProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "sgl") +} + +func (provider *SGLProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "sgl") +} + +func (provider *SGLProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "sgl") +} diff --git a/core/providers/utils.go b/core/providers/utils.go index 2988fac35..fe9062d84 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -3,13 +3,19 @@ package providers import ( + "bytes" + "context" "fmt" + "net/http" + "net/textproto" "net/url" "reflect" + "regexp" + "slices" "strings" "sync" - "github.com/goccy/go-json" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" @@ -17,27 +23,46 @@ import ( "maps" ) -// bifrostResponsePool provides a pool for Bifrost response objects. -var bifrostResponsePool = sync.Pool{ - New: func() interface{} { - return &schemas.BifrostResponse{} - }, +// dataURIRegex is a precompiled regex for matching data URI format patterns. +// It matches patterns like: ... +var dataURIRegex = regexp.MustCompile(`^data:([^;]+)(;base64)?,(.+)$`) + +// base64Regex is a precompiled regex for matching base64 strings. +// It matches strings containing only valid base64 characters with optional padding. +var base64Regex = regexp.MustCompile(`^[A-Za-z0-9+/]*={0,2}$`) + +// fileExtensionToMediaType maps common image file extensions to their corresponding media types. +// This map is used to infer media types from file extensions in URLs. +var fileExtensionToMediaType = map[string]string{ + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".svg": "image/svg+xml", + ".bmp": "image/bmp", } -// acquireBifrostResponse gets a Bifrost response from the pool and resets it. -func acquireBifrostResponse() *schemas.BifrostResponse { - resp := bifrostResponsePool.Get().(*schemas.BifrostResponse) - *resp = schemas.BifrostResponse{} // Reset the struct - return resp -} +// ImageContentType represents the type of image content +type ImageContentType string -// releaseBifrostResponse returns a Bifrost response to the pool. -func releaseBifrostResponse(resp *schemas.BifrostResponse) { - if resp != nil { - bifrostResponsePool.Put(resp) - } +const ( + ImageContentTypeBase64 ImageContentType = "base64" + ImageContentTypeURL ImageContentType = "url" +) + +// URLTypeInfo contains extracted information about a URL +type URLTypeInfo struct { + Type ImageContentType + MediaType *string + DataURLWithoutPrefix *string // URL without the prefix (eg ...) } +// ContextKey is a custom type for context keys to prevent key collisions in the context. +// It provides type safety for context values and ensures that context keys are unique +// across different packages. +type ContextKey string + // mergeConfig merges a default configuration map with custom parameters. // It creates a new map containing all default values, then overrides them with any custom values. // Returns a new map containing the merged configuration. @@ -104,6 +129,49 @@ func prepareParams(params *schemas.ModelParameters) map[string]interface{} { return flatParams } +// IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the +// context is done. The fasthttp client call will continue in its goroutine until it completes +// or times out based on its own settings. This function merely stops *waiting* for the +// fasthttp call and returns an error related to the context. +func makeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) *schemas.BifrostError { + errChan := make(chan error, 1) + + go func() { + // client.Do is a blocking call. + // It will send an error (or nil for success) to errChan when it completes. + errChan <- client.Do(req, resp) + }() + + select { + case <-ctx.Done(): + // Context was cancelled (e.g., deadline exceeded or manual cancellation). + // Return a BifrostError indicating this. + return &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Type: Ptr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: ctx.Err(), + }, + } + case err := <-errChan: + // The fasthttp.Do call completed. + if err != nil { + // The HTTP request itself failed (e.g., connection error, fasthttp timeout). + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + // HTTP request was successful from fasthttp's perspective (err is nil). + // The caller should check resp.StatusCode() for HTTP-level errors (4xx, 5xx). + return nil + } +} + // configureProxy sets up a proxy for the fasthttp client based on the provided configuration. // It supports HTTP, SOCKS5, and environment-based proxy configurations. // Returns the configured client or the original client if proxy configuration is invalid. @@ -157,13 +225,73 @@ func configureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, l return client } +// setExtraHeaders sets additional headers from NetworkConfig to the fasthttp request. +// This allows users to configure custom headers for their provider requests. +// Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. +// The Authorization header is excluded for security reasons. +// It accepts a list of headers (all canonicalized) to skip for security reasons. +// Headers are only set if they don't already exist on the request to avoid overwriting important headers. +func setExtraHeaders(req *fasthttp.Request, extraHeaders map[string]string, skipHeaders *[]string) { + if extraHeaders == nil { + return + } + + for key, value := range extraHeaders { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + // Skip Authorization header for security reasons + if key == "Authorization" { + continue + } + if skipHeaders != nil { + if slices.Contains(*skipHeaders, key) { + continue + } + } + // Only set the header if it doesn't already exist to avoid overwriting important headers + if len(req.Header.Peek(canonicalKey)) == 0 { + req.Header.Set(canonicalKey, value) + } + } +} + +// setExtraHeadersHTTP sets additional headers from NetworkConfig to the standard HTTP request. +// This allows users to configure custom headers for their provider requests. +// Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. +// It accepts a list of headers (all canonicalized) to skip for security reasons. +// Headers are only set if they don't already exist on the request to avoid overwriting important headers. +func setExtraHeadersHTTP(req *http.Request, extraHeaders map[string]string, skipHeaders *[]string) { + if extraHeaders == nil { + return + } + + for key, value := range extraHeaders { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + // Skip Authorization header for security reasons + if key == "Authorization" { + continue + } + if skipHeaders != nil { + if slices.Contains(*skipHeaders, key) { + continue + } + } + // Only set the header if it doesn't already exist to avoid overwriting important headers + if req.Header.Get(canonicalKey) == "" { + req.Header.Set(canonicalKey, value) + } + } +} + // handleProviderAPIError processes error responses from provider APIs. // It attempts to unmarshal the error response and returns a BifrostError // with the appropriate status code and error information. func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { - if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + statusCode := resp.StatusCode() + + if err := sonic.Unmarshal(resp.Body(), &errorResp); err != nil { return &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &statusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, @@ -171,8 +299,6 @@ func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif } } - statusCode := resp.StatusCode() - return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, @@ -183,7 +309,8 @@ func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif // handleProviderResponse handles common response parsing logic for provider responses. // It attempts to parse the response body into the provided response type // and returns either the parsed response or a BifrostError if parsing fails. -func handleProviderResponse[T any](responseBody []byte, response *T) (interface{}, *schemas.BifrostError) { +// If sendBackRawResponse is true, it returns the raw response interface, otherwise nil. +func handleProviderResponse[T any](responseBody []byte, response *T, sendBackRawResponse bool) (interface{}, *schemas.BifrostError) { var rawResponse interface{} var wg sync.WaitGroup @@ -192,11 +319,13 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ wg.Add(2) go func() { defer wg.Done() - structuredErr = json.Unmarshal(responseBody, response) + structuredErr = sonic.Unmarshal(responseBody, response) }() go func() { defer wg.Done() - rawErr = json.Unmarshal(responseBody, &rawResponse) + if sendBackRawResponse { + rawErr = sonic.Unmarshal(responseBody, &rawResponse) + } }() wg.Wait() @@ -210,51 +339,510 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ } } - if rawErr != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderDecodeRaw, - Error: rawErr, - }, + if sendBackRawResponse { + if rawErr != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderDecodeRaw, + Error: rawErr, + }, + } } + + return rawResponse, nil + } + + return nil, nil +} + +// getRoleFromMessage extracts and validates the role from a message map. +func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRole, bool) { + roleVal, exists := msg["role"] + if !exists { + return "", false // Role key doesn't exist + } + + // Try direct assertion to ModelChatMessageRole + roleAsModelType, ok := roleVal.(schemas.ModelChatMessageRole) + if ok { + return roleAsModelType, true + } + + // Try assertion to string and then convert + roleAsString, okStr := roleVal.(string) + if okStr { + return schemas.ModelChatMessageRole(roleAsString), true } - return rawResponse, nil + return "", false // Role is of an unexpected or invalid type } -// float64Ptr creates a pointer to a float64 value. -// This is a helper function for creating pointers to float64 values. -func float64Ptr(f float64) *float64 { - return &f +// Ptr creates a pointer to any value. +// This is a helper function for creating pointers to values. +func Ptr[T any](v T) *T { + return &v } -func setConfigDefaults(config *schemas.ProviderConfig) { - if config.ConcurrencyAndBufferSize.Concurrency == 0 { - config.ConcurrencyAndBufferSize.Concurrency = schemas.DefaultConcurrency +//* IMAGE UTILS *// + +// SanitizeImageURL sanitizes and validates an image URL. +// It handles both data URLs and regular HTTP/HTTPS URLs. +// It also detects raw base64 image data and adds proper data URL headers. +func SanitizeImageURL(rawURL string) (string, error) { + if rawURL == "" { + return rawURL, fmt.Errorf("URL cannot be empty") + } + + // Trim whitespace + rawURL = strings.TrimSpace(rawURL) + + // Check if it's already a proper data URL + if strings.HasPrefix(rawURL, "data:") { + // Validate data URL format + if !dataURIRegex.MatchString(rawURL) { + return rawURL, fmt.Errorf("invalid data URL format") + } + return rawURL, nil + } + + // Check if it looks like raw base64 image data + if isLikelyBase64(rawURL) { + // Detect the image type from the base64 data + mediaType := detectImageTypeFromBase64(rawURL) + + // Remove any whitespace/newlines from base64 data + cleanBase64 := strings.ReplaceAll(strings.ReplaceAll(rawURL, "\n", ""), " ", "") + + // Create proper data URL + return fmt.Sprintf("data:%s;base64,%s", mediaType, cleanBase64), nil + } + + // Parse as regular URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return rawURL, fmt.Errorf("invalid URL format: %w", err) } - if config.ConcurrencyAndBufferSize.BufferSize == 0 { - config.ConcurrencyAndBufferSize.BufferSize = schemas.DefaultBufferSize + // Validate scheme + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return rawURL, fmt.Errorf("URL must use http or https scheme") } - if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { - config.NetworkConfig.DefaultRequestTimeoutInSeconds = schemas.DefaultRequestTimeoutInSeconds + // Validate host + if parsedURL.Host == "" { + return rawURL, fmt.Errorf("URL must have a valid host") + } + + return parsedURL.String(), nil +} + +// ExtractURLTypeInfo extracts type and media type information from a sanitized URL. +// For data URLs, it parses the media type and encoding. +// For regular URLs, it attempts to infer the media type from the file extension. +func ExtractURLTypeInfo(sanitizedURL string) URLTypeInfo { + if strings.HasPrefix(sanitizedURL, "data:") { + return extractDataURLInfo(sanitizedURL) + } + return extractRegularURLInfo(sanitizedURL) +} + +// extractDataURLInfo extracts information from a data URL +func extractDataURLInfo(dataURL string) URLTypeInfo { + // Parse data URL: data:[][;base64], + matches := dataURIRegex.FindStringSubmatch(dataURL) + + if len(matches) != 4 { + return URLTypeInfo{Type: ImageContentTypeBase64} } - if config.NetworkConfig.MaxRetries == 0 { - config.NetworkConfig.MaxRetries = schemas.DefaultMaxRetries + mediaType := matches[1] + isBase64 := matches[2] == ";base64" + + dataURLWithoutPrefix := dataURL + if isBase64 { + dataURLWithoutPrefix = dataURL[len("data:")+len(mediaType)+len(";base64,"):] } - if config.NetworkConfig.RetryBackoffInitial == 0 { - config.NetworkConfig.RetryBackoffInitial = schemas.DefaultRetryBackoffInitial + info := URLTypeInfo{ + MediaType: &mediaType, + DataURLWithoutPrefix: &dataURLWithoutPrefix, } - if config.NetworkConfig.RetryBackoffMax == 0 { - config.NetworkConfig.RetryBackoffMax = schemas.DefaultRetryBackoffMax + if isBase64 { + info.Type = ImageContentTypeBase64 + } else { + info.Type = ImageContentTypeURL // Non-base64 data URL } + + return info } -func StrPtr(s string) *string { - return &s +// extractRegularURLInfo extracts information from a regular HTTP/HTTPS URL +func extractRegularURLInfo(regularURL string) URLTypeInfo { + info := URLTypeInfo{ + Type: ImageContentTypeURL, + } + + // Try to infer media type from file extension + parsedURL, err := url.Parse(regularURL) + if err != nil { + return info + } + + path := strings.ToLower(parsedURL.Path) + + // Check for known file extensions using the map + for ext, mediaType := range fileExtensionToMediaType { + if strings.HasSuffix(path, ext) { + info.MediaType = &mediaType + break + } + } + // For URLs without recognizable extensions, MediaType remains nil + + return info +} + +// detectImageTypeFromBase64 detects the image type from base64 data by examining the header bytes +func detectImageTypeFromBase64(base64Data string) string { + // Remove any whitespace or newlines + cleanData := strings.ReplaceAll(strings.ReplaceAll(base64Data, "\n", ""), " ", "") + + // Check common image format signatures in base64 + switch { + case strings.HasPrefix(cleanData, "/9j/") || strings.HasPrefix(cleanData, "/9k/"): + // JPEG images typically start with /9j/ or /9k/ in base64 (FFD8 in hex) + return "image/jpeg" + case strings.HasPrefix(cleanData, "iVBORw0KGgo"): + // PNG images start with iVBORw0KGgo in base64 (89504E470D0A1A0A in hex) + return "image/png" + case strings.HasPrefix(cleanData, "R0lGOD"): + // GIF images start with R0lGOD in base64 (474946 in hex) + return "image/gif" + case strings.HasPrefix(cleanData, "Qk"): + // BMP images start with Qk in base64 (424D in hex) + return "image/bmp" + case strings.HasPrefix(cleanData, "UklGR") && len(cleanData) >= 16 && cleanData[12:16] == "V0VC": + // WebP images start with RIFF header (UklGR in base64) and have WEBP signature at offset 8-11 (V0VC in base64) + return "image/webp" + case strings.HasPrefix(cleanData, "PHN2Zy") || strings.HasPrefix(cleanData, "PD94bW"): + // SVG images often start with = 12 && + bytes.Equal(audioData[:4], riff) && + bytes.Equal(audioData[8:12], wave) { + return "audio/wav" + } + // MP3: ID3v2 tag (keep this check for MP3) + if len(audioData) >= 3 && bytes.Equal(audioData[:3], id3) { + return "audio/mp3" + } + // AAC: ADIF or ADTS (0xFFF sync) - check before MP3 frame sync to avoid misclassification + if bytes.HasPrefix(audioData, adif) { + return "audio/aac" + } + if len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xF6) == 0xF0 { + return "audio/aac" + } + // AIFF / AIFC (map both to audio/aiff) + if len(audioData) >= 12 && bytes.Equal(audioData[:4], form) && + (bytes.Equal(audioData[8:12], aiff) || bytes.Equal(audioData[8:12], aifc)) { + return "audio/aiff" + } + // FLAC + if bytes.HasPrefix(audioData, flac) { + return "audio/flac" + } + // OGG container + if bytes.HasPrefix(audioData, oggs) { + return "audio/ogg" + } + // MP3: MPEG frame sync (cover common variants) - check after AAC to avoid misclassification + if len(audioData) >= 2 && audioData[0] == 0xFF && + (audioData[1] == 0xFB || audioData[1] == 0xF3 || audioData[1] == 0xF2 || audioData[1] == 0xFA) { + return "audio/mp3" + } + // Fallback within supported set + return "audio/mp3" +} + +// newUnsupportedOperationError creates a standardized error for unsupported operations. +// This helper reduces code duplication across providers that don't support certain operations. +func newUnsupportedOperationError(operation string, providerName string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Provider: schemas.ModelProvider(providerName), + Error: schemas.ErrorField{ + Message: fmt.Sprintf("%s is not supported by %s provider", operation, providerName), + }, + } +} + +// checkOperationAllowed enforces per-op gating using schemas.Operation. +// Behavior: +// - If no gating is configured (config == nil or AllowedRequests == nil), the operation is allowed. +// - If gating is configured, returns an error when the operation is not explicitly allowed. +func checkOperationAllowed(defaultProvider schemas.ModelProvider, config *schemas.CustomProviderConfig, operation schemas.Operation) *schemas.BifrostError { + // No gating configured => allowed + if config == nil || config.AllowedRequests == nil { + return nil + } + // Explicitly allowed? + if config.IsOperationAllowed(operation) { + return nil + } + // Gated and not allowed + resolved := getProviderName(defaultProvider, config) + return newUnsupportedOperationError(string(operation), string(resolved)) +} + +// newConfigurationError creates a standardized error for configuration errors. +// This helper reduces code duplication across providers that have configuration errors. +func newConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Provider: providerType, + Error: schemas.ErrorField{ + Message: message, + }, + } +} + +// newBifrostOperationError creates a standardized error for bifrost operation errors. +// This helper reduces code duplication across providers that have bifrost operation errors. +func newBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: true, + Provider: providerType, + Error: schemas.ErrorField{ + Message: message, + Error: err, + }, + } +} + +// newProviderAPIError creates a standardized error for provider API errors. +// This helper reduces code duplication across providers that have provider API errors. +func newProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Provider: providerType, + StatusCode: &statusCode, + Type: errorType, + EventID: eventID, + Error: schemas.ErrorField{ + Message: message, + Error: err, + Type: errorType, + }, + } +} + +// processAndSendResponse handles post-hook processing and sends the response to the channel. +// This utility reduces code duplication across streaming implementations by encapsulating +// the common pattern of running post hooks, handling errors, and sending responses with +// proper context cancellation handling. +func processAndSendResponse( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + response *schemas.BifrostResponse, + responseChan chan *schemas.BifrostStream, + logger schemas.Logger, +) { + // Run post hooks on the response + processedResponse, bifrostErr := postHookRunner(&ctx, response, nil) + if bifrostErr != nil { + // check if it is a stream error + if handleStreamControlSkip(logger, bifrostErr) { + return + } + + // Send error response and close channel + errorResponse := &schemas.BifrostStream{ + BifrostError: bifrostErr, + } + + // Try to send error response before closing + select { + case responseChan <- errorResponse: + case <-ctx.Done(): + } + return + } + + // Send the response + select { + case responseChan <- &schemas.BifrostStream{ + BifrostResponse: processedResponse, + BifrostError: bifrostErr, + }: + case <-ctx.Done(): + return + } +} + +// processAndSendBifrostError handles post-hook processing and sends the bifrost error to the channel. +// This utility reduces code duplication across streaming implementations by encapsulating +// the common pattern of running post hooks, handling errors, and sending responses with +// proper context cancellation handling. +func processAndSendBifrostError( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + bifrostErr *schemas.BifrostError, + responseChan chan *schemas.BifrostStream, + logger schemas.Logger, +) { + // Send scanner error through channel + processedResponse, processedError := postHookRunner(&ctx, nil, bifrostErr) + + if handleStreamControlSkip(logger, processedError) { + return + } + + errorResponse := &schemas.BifrostStream{ + BifrostResponse: processedResponse, + BifrostError: processedError, + } + select { + case responseChan <- errorResponse: + case <-ctx.Done(): + } +} + +// processAndSendError handles post-hook processing and sends the error to the channel. +// This utility reduces code duplication across streaming implementations by encapsulating +// the common pattern of running post hooks, handling errors, and sending responses with +// proper context cancellation handling. +func processAndSendError( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + err error, + responseChan chan *schemas.BifrostStream, + logger schemas.Logger, +) { + // Send scanner error through channel + bifrostError := + &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Error reading stream: %v", err), + Error: err, + }, + } + processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError) + + if handleStreamControlSkip(logger, processedError) { + return + } + + errorResponse := &schemas.BifrostStream{ + BifrostResponse: processedResponse, + BifrostError: processedError, + } + select { + case responseChan <- errorResponse: + case <-ctx.Done(): + } +} + +func createBifrostChatCompletionChunkResponse( + id string, + usage *schemas.LLMUsage, + finishReason *string, + currentChunkIndex int, + params *schemas.ModelParameters, + providerName schemas.ModelProvider, +) *schemas.BifrostResponse { + response := &schemas.BifrostResponse{ + ID: id, + Object: "chat.completion.chunk", + Usage: usage, + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: finishReason, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{}, // empty delta + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: providerName, + ChunkIndex: currentChunkIndex + 1, + }, + } + if params != nil { + response.ExtraFields.Params = *params + } + return response +} + +func handleStreamEndWithSuccess( + ctx context.Context, + response *schemas.BifrostResponse, + postHookRunner schemas.PostHookRunner, + responseChan chan *schemas.BifrostStream, + logger schemas.Logger, +) { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + processAndSendResponse(ctx, postHookRunner, response, responseChan, logger) +} + +func handleStreamControlSkip(logger schemas.Logger, bifrostErr *schemas.BifrostError) bool { + if bifrostErr == nil || bifrostErr.StreamControl == nil { + return false + } + if bifrostErr.StreamControl.SkipStream != nil && *bifrostErr.StreamControl.SkipStream { + if bifrostErr.StreamControl.LogError != nil && *bifrostErr.StreamControl.LogError { + logger.Warn("Error in stream: " + bifrostErr.Error.Message) + } + return true + } + return false +} + +// getProviderName extracts the provider name from custom provider configuration. +// If a custom provider key is specified, it returns that; otherwise, it returns the default provider. +// Note: CustomProviderKey is internally set by Bifrost and should always match the provider name. +func getProviderName(defaultProvider schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) schemas.ModelProvider { + if customConfig != nil { + if key := strings.TrimSpace(customConfig.CustomProviderKey); key != "" { + return schemas.ModelProvider(key) + } + } + return defaultProvider } diff --git a/core/providers/vertex.go b/core/providers/vertex.go new file mode 100644 index 000000000..808bb0540 --- /dev/null +++ b/core/providers/vertex.go @@ -0,0 +1,722 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Vertex provider implementation. +package providers + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "golang.org/x/oauth2/google" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" +) + +type VertexError struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` +} + +// vertexClientPool provides a pool/cache for authenticated Vertex HTTP clients. +// This avoids creating and authenticating clients for every request. +// Uses sync.Map for atomic operations without explicit locking. +var vertexClientPool sync.Map + +// getClientKey generates a unique key for caching authenticated clients. +// It uses a hash of the auth credentials for security. +func getClientKey(authCredentials string) string { + hash := sha256.Sum256([]byte(authCredentials)) + return hex.EncodeToString(hash[:]) +} + +// removeVertexClient removes a specific client from the pool. +// This should be called when: +// - API returns authentication/authorization errors (401, 403) +// - Auth client creation fails +// - Network errors that might indicate credential issues +// This ensures we don't keep using potentially invalid clients. +func removeVertexClient(authCredentials string) { + clientKey := getClientKey(authCredentials) + vertexClientPool.Delete(clientKey) +} + +// VertexProvider implements the Provider interface for Google's Vertex AI API. +type VertexProvider struct { + logger schemas.Logger // Logger for provider operations + networkConfig schemas.NetworkConfig // Network configuration including extra headers + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +// NewVertexProvider creates a new Vertex provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VertexProvider, error) { + config.CheckAndSetDefaults() + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + // openAIResponsePool.Put(&schemas.BifrostResponse{}) + anthropicChatResponsePool.Put(&AnthropicChatResponse{}) + + } + + return &VertexProvider{ + logger: logger, + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" + +// getAuthClient returns an authenticated HTTP client for Vertex AI API requests. +// This function implements client pooling to avoid creating and authenticating +// clients for every request, which significantly improves performance by: +// - Avoiding repeated JWT config creation +// - Reusing OAuth2 token refresh logic +// - Reducing authentication overhead +func getAuthClient(key schemas.Key) (*http.Client, error) { + if key.VertexKeyConfig == nil { + return nil, fmt.Errorf("vertex key config is not set") + } + + authCredentials := key.VertexKeyConfig.AuthCredentials + var client *http.Client + // Generate cache key from credentials + clientKey := getClientKey(authCredentials) + + // Try to get existing client from pool + if value, exists := vertexClientPool.Load(clientKey); exists { + return value.(*http.Client), nil + } + + if authCredentials == "" { + // When auth credentials are not explicitly set, use default credentials + // This will automatically detect credentials from the environment/server + var err error + client, err = google.DefaultClient(context.Background(), cloudPlatformScope) + if err != nil { + return nil, fmt.Errorf("failed to create default client: %w", err) + } + } else { + conf, err := google.JWTConfigFromJSON([]byte(authCredentials), cloudPlatformScope) + if err != nil { + return nil, fmt.Errorf("failed to create JWT config: %w", err) + } + client = conf.Client(context.Background()) + } + + // Store the client using LoadOrStore to handle race conditions + // If another goroutine stored a client while we were creating ours, use theirs + actual, _ := vertexClientPool.LoadOrStore(clientKey, client) + return actual.(*http.Client), nil +} + +// GetProviderKey returns the provider identifier for Vertex. +func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Vertex +} + +// TextCompletion is not supported by the Vertex provider. +// Returns an error indicating that text completion is not available. +func (provider *VertexProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "vertex") +} + +// ChatCompletion performs a chat completion request to the Vertex API. +// It supports both text and image content in messages. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if key.VertexKeyConfig == nil { + return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) + } + + // Format messages for Vertex API + var formattedMessages []map[string]interface{} + var preparedParams map[string]interface{} + + if strings.Contains(model, "claude") { + formattedMessages, preparedParams = prepareAnthropicChatRequest(messages, params) + } else { + formattedMessages, preparedParams = prepareOpenAIChatRequest(messages, params) + } + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + if strings.Contains(model, "claude") { + if _, exists := requestBody["anthropic_version"]; !exists { + requestBody["anthropic_version"] = "vertex-2023-10-16" + } + + delete(requestBody, "model") + } + + delete(requestBody, "region") + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, newConfigurationError("project ID is not set", schemas.Vertex) + } + + region := key.VertexKeyConfig.Region + if region == "" { + return nil, newConfigurationError("region is not set in key config", schemas.Vertex) + } + + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + + if strings.Contains(model, "claude") { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) + } + + // Create request + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + req.Header.Set("Content-Type", "application/json") + + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) + } + + // Make request + resp, err := client.Do(req) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: Ptr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } + // Remove client from pool for non-context errors (could be auth/network issues) + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.Vertex) + } + defer resp.Body.Close() + + // Handle error response + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, newBifrostOperationError("error reading response", err, schemas.Vertex) + } + + if resp.StatusCode != http.StatusOK { + // Remove client from pool for authentication/authorization errors + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + } + + var openAIErr schemas.BifrostError + var vertexErr []VertexError + + if err := sonic.Unmarshal(body, &openAIErr); err != nil { + // Try Vertex error format if OpenAI format fails + if err := sonic.Unmarshal(body, &vertexErr); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) + } + + if len(vertexErr) > 0 { + return nil, newProviderAPIError(vertexErr[0].Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil) + } + } + + return nil, newProviderAPIError(openAIErr.Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil) + } + + if strings.Contains(model, "claude") { + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(body, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{} + var err *schemas.BifrostError + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: schemas.Vertex, + } + + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil + } else { + // Pre-allocate response structs from pools + // response := acquireOpenAIResponse() + response := &schemas.BifrostResponse{} + // defer releaseOpenAIResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(body, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = schemas.Vertex + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil + } +} + +// Embedding generates embeddings for the given input text(s) using Vertex AI. +// All Vertex AI embedding models use the same response format regardless of the model type. +// Returns a BifrostResponse containing the embedding(s) and any error that occurred. +func (provider *VertexProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if key.VertexKeyConfig == nil { + return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, newConfigurationError("project ID is not set", schemas.Vertex) + } + + region := key.VertexKeyConfig.Region + if region == "" { + return nil, newConfigurationError("region is not set in key config", schemas.Vertex) + } + + // Validate input + if input.Text == nil && len(input.Texts) == 0 { + return nil, newConfigurationError("embedding input texts are empty", schemas.Vertex) + } + + // All Vertex AI embedding models use the same native Vertex embedding API + return provider.handleVertexEmbedding(ctx, model, key, input, params) +} + +// handleVertexEmbedding handles embedding requests using Vertex's native embedding API +// This is used for all Vertex AI embedding models as they all use the same response format +func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Prepare request body for Vertex's native embedding API + texts := input.Texts + + if len(texts) == 0 && input.Text != nil { + texts = []string{*input.Text} + } + + instances := make([]map[string]interface{}, 0, len(texts)) + requestBody := map[string]interface{}{ + "instances": instances, + } + + // Add text instances + for _, text := range texts { + instance := map[string]interface{}{ + "content": text, + } + + // Add optional task_type if specified in params + if params != nil && params.ExtraParams != nil { + if taskType, exists := params.ExtraParams["task_type"]; exists { + instance["task_type"] = taskType + } + if title, exists := params.ExtraParams["title"]; exists { + instance["title"] = title + } + } + + requestBody["instances"] = append(requestBody["instances"].([]map[string]interface{}), instance) + } + + // Add parameters + parameters := make(map[string]interface{}) + + // Set autoTruncate (defaults to true) + autoTruncate := true + if params != nil && params.ExtraParams != nil { + if autoTruncateVal, exists := params.ExtraParams["autoTruncate"]; exists { + if autoTruncateBool, ok := autoTruncateVal.(bool); ok { + autoTruncate = autoTruncateBool + } + } + } + parameters["autoTruncate"] = autoTruncate + + // Add outputDimensionality if specified + if params != nil && params.Dimensions != nil { + parameters["outputDimensionality"] = *params.Dimensions + } + + // Add any other extra parameters + if params != nil && params.ExtraParams != nil { + for k, v := range params.ExtraParams { + // Skip parameters we've already handled + if k != "task_type" && k != "title" && k != "autoTruncate" && k != "outputDimensionality" { + parameters[k] = v + } + } + } + + requestBody["parameters"] = parameters + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) + } + + // Build the native Vertex embedding API endpoint + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", + key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, model) + + // Create request + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.Vertex) + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + req.Header.Set("Content-Type", "application/json") + + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) + } + + // Make request + resp, err := client.Do(req) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: Ptr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } + // Remove client from pool for non-context errors (could be auth/network issues) + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.Vertex) + } + defer resp.Body.Close() + + // Handle error response + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, newBifrostOperationError("error reading response", err, schemas.Vertex) + } + + if resp.StatusCode != http.StatusOK { + // Remove client from pool for authentication/authorization errors + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + } + + // Try to parse Vertex's error format + var vertexError map[string]interface{} + if err := sonic.Unmarshal(body, &vertexError); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) + } + + // Extract error message from Vertex's error format + errorMessage := "Unknown error" + if errorObj, exists := vertexError["error"]; exists { + if errorMap, ok := errorObj.(map[string]interface{}); ok { + if message, exists := errorMap["message"]; exists { + if msgStr, ok := message.(string); ok { + errorMessage = msgStr + } + } + } + } + + return nil, newProviderAPIError(errorMessage, nil, resp.StatusCode, schemas.Vertex, nil, nil) + } + + // Parse Vertex's native embedding response + var vertexResponse map[string]interface{} + if err := sonic.Unmarshal(body, &vertexResponse); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) + } + + // Convert Vertex's response format to Bifrost format + bifrostResponse, bifrostErr := provider.convertVertexEmbeddingResponse(vertexResponse, model, params) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Set provider and raw response + bifrostResponse.ExtraFields.Provider = schemas.Vertex + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = vertexResponse + } + + return bifrostResponse, nil +} + +// convertVertexEmbeddingResponse converts Vertex's native embedding response to Bifrost format +func (provider *VertexProvider) convertVertexEmbeddingResponse(vertexResponse map[string]interface{}, model string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Extract predictions from Vertex's response + predictions, exists := vertexResponse["predictions"] + if !exists { + return nil, newBifrostOperationError("missing predictions in response", fmt.Errorf("no predictions field"), schemas.Vertex) + } + + predictionsArray, ok := predictions.([]interface{}) + if !ok { + return nil, newBifrostOperationError("invalid predictions format", fmt.Errorf("predictions is not an array"), schemas.Vertex) + } + + // Convert to Bifrost embedding format + var embeddings []schemas.BifrostEmbedding + var usage *schemas.LLMUsage + + for i, pred := range predictionsArray { + predMap, ok := pred.(map[string]interface{}) + if !ok { + continue + } + + embeddingsObj, exists := predMap["embeddings"] + if !exists { + continue + } + + embMap, ok := embeddingsObj.(map[string]interface{}) + if !ok { + continue + } + + // Extract values + values, exists := embMap["values"] + if !exists { + continue + } + + valuesArray, ok := values.([]interface{}) + if !ok { + continue + } + + // Convert to float32 in a single pass + embeddingFloat32 := make([]float32, 0, len(valuesArray)) + for _, v := range valuesArray { + if f64, ok := v.(float64); ok { + embeddingFloat32 = append(embeddingFloat32, float32(f64)) + } + } + + // Create embedding object + embedding := schemas.BifrostEmbedding{ + Object: "embedding", + Embedding: schemas.BifrostEmbeddingResponse{ + EmbeddingArray: &embeddingFloat32, + }, + Index: i, + } + + // Extract statistics if available + if stats, exists := embMap["statistics"]; exists { + if statsMap, ok := stats.(map[string]interface{}); ok { + if tokenCount, exists := statsMap["token_count"]; exists { + if count, ok := tokenCount.(float64); ok { + if usage == nil { + usage = &schemas.LLMUsage{} + } + usage.TotalTokens += int(count) + usage.PromptTokens += int(count) + } + } + } + } + + embeddings = append(embeddings, embedding) + } + + // Create final response + response := &schemas.BifrostResponse{ + Object: "list", + Model: model, + Data: embeddings, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{}, + } + + if params != nil { + response.ExtraFields.Params = *params + } + + return response, nil +} + +// ChatCompletionStream performs a streaming chat completion request to the Vertex API. +// It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models). +// Returns a channel of BifrostResponse objects for streaming results or an error if the request fails. +func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if key.VertexKeyConfig == nil { + return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, newConfigurationError("project ID is not set", schemas.Vertex) + } + + region := key.VertexKeyConfig.Region + if region == "" { + return nil, newConfigurationError("region is not set in key config", schemas.Vertex) + } + + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) + } + + if strings.Contains(model, "claude") { + // Use Anthropic-style streaming for Claude models + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + if _, exists := requestBody["anthropic_version"]; !exists { + requestBody["anthropic_version"] = "vertex-2023-10-16" + } + + delete(requestBody, "model") + delete(requestBody, "region") + + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) + + // Prepare headers for Vertex Anthropic + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared Anthropic streaming logic + return handleAnthropicStreaming( + ctx, + client, + url, + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Vertex, + params, + postHookRunner, + provider.logger, + ) + } else { + // Use OpenAI-style streaming for non-Claude models + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + delete(requestBody, "region") + + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + + // Prepare headers for Vertex OpenAI-compatible + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Use shared OpenAI streaming logic + return handleOpenAIStreaming( + ctx, + client, + url, + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.Vertex, + params, + postHookRunner, + provider.logger, + ) + } +} + +func (provider *VertexProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "vertex") +} + +func (provider *VertexProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "vertex") +} + +func (provider *VertexProvider) Transcription(ctx context.Context, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "vertex") +} + +func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "vertex") +} diff --git a/core/schemas/account.go b/core/schemas/account.go index 7800c2dd3..44563ca7b 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,14 +1,52 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas +import "context" + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. type Key struct { - Value string `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access - Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + ID string `json:"id"` // The unique identifier for the key (not used by bifrost, but can be used by users to identify the key) + Value string `json:"value"` // The actual API key value + Models []string `json:"models"` // List of models this key can access + Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration + VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration + BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration +} + +// AzureKeyConfig represents the Azure-specific configuration. +// It contains Azure-specific settings required for service access and deployment management. +type AzureKeyConfig struct { + Endpoint string `json:"endpoint"` // Azure service endpoint URL + Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names + APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-08-01-preview" +} + +// VertexKeyConfig represents the Vertex-specific configuration. +// It contains Vertex-specific settings required for authentication and service access. +type VertexKeyConfig struct { + ProjectID string `json:"project_id,omitempty"` + Region string `json:"region,omitempty"` + AuthCredentials string `json:"auth_credentials,omitempty"` } +// NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string. + +// BedrockKeyConfig represents the AWS Bedrock-specific configuration. +// It contains AWS-specific settings required for authentication and service access. +type BedrockKeyConfig struct { + AccessKey string `json:"access_key,omitempty"` // AWS access key for authentication + SecretKey string `json:"secret_key,omitempty"` // AWS secret access key for authentication + SessionToken *string `json:"session_token,omitempty"` // AWS session token for temporary credentials + Region *string `json:"region,omitempty"` // AWS region for service access + ARN *string `json:"arn,omitempty"` // Amazon Resource Name for resource identification + Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles +} + +// NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings. +// To use Bedrock API Key authentication, set Value in Key struct instead. + // Account defines the interface for managing provider accounts and their configurations. // It provides methods to access provider-specific settings, API keys, and configurations. type Account interface { @@ -18,7 +56,10 @@ type Account interface { // GetKeysForProvider returns the API keys configured for a specific provider. // The keys include their values, supported models, and weights for load balancing. - GetKeysForProvider(providerKey ModelProvider) ([]Key, error) + // The context can carry data from any source that sets values before the Bifrost request, + // including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context. + // This enables dynamic key selection based on any context values present during the request. + GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error) // GetConfigForProvider returns the configuration for a specific provider. // This includes network settings, authentication details, and other provider-specific diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 4e3f06041..667d5e837 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -1,8 +1,14 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas +import ( + "fmt" + + "github.com/bytedance/sonic" +) + const ( - DefaultInitialPoolSize = 100 + DefaultInitialPoolSize = 5000 ) // BifrostConfig represents the configuration for initializing a Bifrost instance. @@ -12,76 +18,294 @@ type BifrostConfig struct { Account Account Plugins []Plugin Logger Logger - InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. - DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. + InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. + DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. + MCPConfig *MCPConfig // MCP (Model Context Protocol) configuration for tool integration } // ModelChatMessageRole represents the role of a chat message type ModelChatMessageRole string const ( - RoleAssistant ModelChatMessageRole = "assistant" - RoleUser ModelChatMessageRole = "user" - RoleSystem ModelChatMessageRole = "system" - RoleChatbot ModelChatMessageRole = "chatbot" - RoleTool ModelChatMessageRole = "tool" + ModelChatMessageRoleAssistant ModelChatMessageRole = "assistant" + ModelChatMessageRoleUser ModelChatMessageRole = "user" + ModelChatMessageRoleSystem ModelChatMessageRole = "system" + ModelChatMessageRoleChatbot ModelChatMessageRole = "chatbot" + ModelChatMessageRoleTool ModelChatMessageRole = "tool" ) // ModelProvider represents the different AI model providers supported by Bifrost. type ModelProvider string const ( - OpenAI ModelProvider = "openai" - Azure ModelProvider = "azure" - Anthropic ModelProvider = "anthropic" - Bedrock ModelProvider = "bedrock" - Cohere ModelProvider = "cohere" + OpenAI ModelProvider = "openai" + Azure ModelProvider = "azure" + Anthropic ModelProvider = "anthropic" + Bedrock ModelProvider = "bedrock" + Cohere ModelProvider = "cohere" + Vertex ModelProvider = "vertex" + Mistral ModelProvider = "mistral" + Ollama ModelProvider = "ollama" + Groq ModelProvider = "groq" + SGL ModelProvider = "sgl" + Parasail ModelProvider = "parasail" + Cerebras ModelProvider = "cerebras" + Gemini ModelProvider = "gemini" + OpenRouter ModelProvider = "openrouter" +) + +// SupportedBaseProviders is the list of base providers allowed for custom providers. +var SupportedBaseProviders = []ModelProvider{ + Anthropic, + Bedrock, + Cohere, + Gemini, + OpenAI, +} + +// StandardProviders is the list of all built-in (non-custom) providers. +var StandardProviders = []ModelProvider{ + Anthropic, + Azure, + Bedrock, + Cerebras, + Cohere, + Gemini, + Groq, + Mistral, + Ollama, + OpenAI, + Parasail, + SGL, + Vertex, + OpenRouter, +} + +// RequestType represents the type of request being made to a provider. +type RequestType string + +const ( + TextCompletionRequest RequestType = "text_completion" + ChatCompletionRequest RequestType = "chat_completion" + ChatCompletionStreamRequest RequestType = "chat_completion_stream" + EmbeddingRequest RequestType = "embedding" + SpeechRequest RequestType = "speech" + SpeechStreamRequest RequestType = "speech_stream" + TranscriptionRequest RequestType = "transcription" + TranscriptionStreamRequest RequestType = "transcription_stream" +) + +// BifrostContextKey is a type for context keys used in Bifrost. +type BifrostContextKey string + +// BifrostContextKeyRequestType is a context key for the request type. +const ( + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" + BifrostContextKeyRequestType BifrostContextKey = "bifrost-request-type" + BifrostContextKeyRequestProvider BifrostContextKey = "bifrost-request-provider" + BifrostContextKeyRequestModel BifrostContextKey = "bifrost-request-model" ) +// NOTE: for custom plugin implementation dealing with streaming short circuit, +// make sure to mark BifrostContextKeyStreamEndIndicator as true at the end of the stream. + //* Request Structs // RequestInput represents the input for a model request, which can be either -// a text completion or a chat completion, but either one must be provided. +// a text completion, a chat completion, an embedding request, a speech request, or a transcription request. type RequestInput struct { - TextCompletionInput *string - ChatCompletionInput *[]Message + TextCompletionInput *string `json:"text_completion_input,omitempty"` + ChatCompletionInput *[]BifrostMessage `json:"chat_completion_input,omitempty"` + EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"` + SpeechInput *SpeechInput `json:"speech_input,omitempty"` + TranscriptionInput *TranscriptionInput `json:"transcription_input,omitempty"` +} + +// EmbeddingInput represents the input for an embedding request. +type EmbeddingInput struct { + Text *string + Texts []string + Embedding []int + Embeddings [][]int +} + +func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { + // enforce one-of + set := 0 + if e.Text != nil { + set++ + } + if e.Texts != nil { + set++ + } + if e.Embedding != nil { + set++ + } + if e.Embeddings != nil { + set++ + } + if set == 0 { + return nil, fmt.Errorf("embedding input is empty") + } + if set > 1 { + return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings") + } + + if e.Text != nil { + return sonic.Marshal(*e.Text) + } + if e.Texts != nil { + return sonic.Marshal(e.Texts) + } + if e.Embedding != nil { + return sonic.Marshal(e.Embedding) + } + if e.Embeddings != nil { + return sonic.Marshal(e.Embeddings) + } + + return nil, fmt.Errorf("invalid embedding input") +} + +func (e *EmbeddingInput) UnmarshalJSON(data []byte) error { + // Try string + var s string + if err := sonic.Unmarshal(data, &s); err == nil { + e.Text = &s + return nil + } + // Try []string + var ss []string + if err := sonic.Unmarshal(data, &ss); err == nil { + e.Texts = ss + return nil + } + // Try []int + var i []int + if err := sonic.Unmarshal(data, &i); err == nil { + e.Embedding = i + return nil + } + // Try [][]int + var i2 [][]int + if err := sonic.Unmarshal(data, &i2); err == nil { + e.Embeddings = i2 + return nil + } + + return fmt.Errorf("unsupported embedding input shape") +} + +// SpeechInput represents the input for a speech request. +type SpeechInput struct { + Input string `json:"input"` + VoiceConfig SpeechVoiceInput `json:"voice"` + Instructions string `json:"instructions,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` // Default is "mp3" +} + +type SpeechVoiceInput struct { + Voice *string + MultiVoiceConfig []VoiceConfig +} + +type VoiceConfig struct { + Speaker string `json:"speaker"` + Voice string `json:"voice"` +} + +// MarshalJSON implements custom JSON marshalling for SpeechVoiceInput. +// It marshals either Voice or MultiVoiceConfig directly without wrapping. +func (tc SpeechVoiceInput) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if tc.Voice != nil && len(tc.MultiVoiceConfig) > 0 { + return nil, fmt.Errorf("both Voice and MultiVoiceConfig are set; only one should be non-nil") + } + + if tc.Voice != nil { + return sonic.Marshal(*tc.Voice) + } + if len(tc.MultiVoiceConfig) > 0 { + return sonic.Marshal(tc.MultiVoiceConfig) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput. +// It determines whether "voice" is a string or a VoiceConfig object/array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (tc *SpeechVoiceInput) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + tc.Voice = &stringContent + return nil + } + + // Try to unmarshal as an array of VoiceConfig objects + var voiceConfigs []VoiceConfig + if err := sonic.Unmarshal(data, &voiceConfigs); err == nil { + // Validate each VoiceConfig and append to MultiVoiceConfig + for _, config := range voiceConfigs { + if config.Voice == "" { + return fmt.Errorf("voice config has empty voice field") + } + tc.MultiVoiceConfig = append(tc.MultiVoiceConfig, config) + } + return nil + } + + return fmt.Errorf("voice field is neither a string, nor an array of VoiceConfig objects") +} + +type TranscriptionInput struct { + File []byte `json:"file"` + Language *string `json:"language,omitempty"` + Prompt *string `json:"prompt,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` // Default is "json" + Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini } // BifrostRequest represents a request to be processed by Bifrost. -// It must be provided when calling the Bifrost for text completion or chat completion. +// It must be provided when calling the Bifrost for text completion, chat completion, or embedding. // It contains the model identifier, input data, and parameters for the request. type BifrostRequest struct { - Model string - Input RequestInput - Params *ModelParameters + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input RequestInput `json:"input"` + Params *ModelParameters `json:"params,omitempty"` // Fallbacks are tried in order, the first one to succeed is returned // Provider config must be available for each fallback's provider in account's GetConfigForProvider, // else it will be skipped. - Fallbacks []Fallback + Fallbacks []Fallback `json:"fallbacks,omitempty"` } // Fallback represents a fallback model to be used if the primary model is not available. type Fallback struct { - Provider ModelProvider - Model string + Provider ModelProvider `json:"provider"` + Model string `json:"model"` } // ModelParameters represents the parameters that can be used to configure // your request to the model. Bifrost follows a standard set of parameters which // mapped to the provider's parameters. type ModelParameters struct { - ToolChoice *ToolChoice `json:"tool_choice,omitempty"` - Tools *[]Tool `json:"tools,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output - TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling - TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens - ParallelToolCalls *bool `json:"parallel_tool_calls"` // Enables parallel tool calls - + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool + Tools *[]Tool `json:"tools,omitempty"` // Tools to use + Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate + StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls + EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") + Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output + User *string `json:"user,omitempty"` // User identifier for tracking // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` @@ -89,10 +313,11 @@ type ModelParameters struct { // FunctionParameters represents the parameters for a function definition. type FunctionParameters struct { - Type string `json:"type,"` // Type of the parameters + Type string `json:"type"` // Type of the parameters Description *string `json:"description,omitempty"` // Description of the parameters - Required []string `json:"required"` // Required parameter names - Properties map[string]interface{} `json:"properties"` // Parameter properties + Required []string `json:"required,omitempty"` // Required parameter names + Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties + Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters } // Function represents a function that can be called by the model. @@ -114,16 +339,16 @@ type Tool struct { type ToolChoiceType string const ( - // ToolChoiceNone means no tool will be called - ToolChoiceNone ToolChoiceType = "none" - // ToolChoiceAuto means the model can choose whether to call a tool - ToolChoiceAuto ToolChoiceType = "auto" - // ToolChoiceAny means any tool can be called - ToolChoiceAny ToolChoiceType = "any" - // ToolChoiceTool means a specific tool must be called - ToolChoiceTool ToolChoiceType = "tool" - // ToolChoiceRequired means a tool must be called - ToolChoiceRequired ToolChoiceType = "required" + // ToolChoiceTypeNone means no tool will be called + ToolChoiceTypeNone ToolChoiceType = "none" + // ToolChoiceTypeAuto means the model can choose whether to call a tool + ToolChoiceTypeAuto ToolChoiceType = "auto" + // ToolChoiceTypeAny means any tool can be called + ToolChoiceTypeAny ToolChoiceType = "any" + // ToolChoiceTypeFunction means a specific tool must be called (converted to "tool" for Anthropic) + ToolChoiceTypeFunction ToolChoiceType = "function" + // ToolChoiceTypeRequired means a tool must be called + ToolChoiceTypeRequired ToolChoiceType = "required" ) // ToolChoiceFunction represents a specific function to be called. @@ -131,26 +356,157 @@ type ToolChoiceFunction struct { Name string `json:"name"` // Name of the function to call } -// ToolChoice represents how a tool should be chosen for a request. +// ToolChoiceStruct represents a specific tool choice. +type ToolChoiceStruct struct { + Type ToolChoiceType `json:"type"` // Type of tool choice + Function ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction +} + +// ToolChoice represents how a tool should be chosen for a request. (either a string or a struct) type ToolChoice struct { - Type ToolChoiceType `json:"type"` // Type of tool choice - Function ToolChoiceFunction `json:"function"` // Function to call if type is ToolChoiceTool + ToolChoiceStr *string + ToolChoiceStruct *ToolChoiceStruct +} + +// MarshalJSON implements custom JSON marshalling for ToolChoice. +// It marshals either ToolChoiceStr or ToolChoiceStruct directly without wrapping. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if tc.ToolChoiceStr != nil && tc.ToolChoiceStruct != nil { + return nil, fmt.Errorf("both ToolChoiceStr and ToolChoiceStruct are set; only one should be non-nil") + } + + if tc.ToolChoiceStr != nil { + return sonic.Marshal(*tc.ToolChoiceStr) + } + if tc.ToolChoiceStruct != nil { + return sonic.Marshal(*tc.ToolChoiceStruct) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ToolChoice. +// It determines whether "tool_choice" is a string or struct and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (tc *ToolChoice) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + tc.ToolChoiceStr = &stringContent + return nil + } + + // Try to unmarshal as a direct struct of ToolChoiceStruct + var toolChoiceStruct ToolChoiceStruct + if err := sonic.Unmarshal(data, &toolChoiceStruct); err == nil { + // Validate the Type field is not empty and is a valid value + if toolChoiceStruct.Type == "" { + return fmt.Errorf("tool_choice struct has empty type field") + } + + tc.ToolChoiceStruct = &toolChoiceStruct + return nil + } + + return fmt.Errorf("tool_choice field is neither a string nor a struct") +} + +// BifrostMessage represents a message in a chat conversation. +type BifrostMessage struct { + Role ModelChatMessageRole `json:"role"` + Content MessageContent `json:"content"` + + // Embedded pointer structs - when non-nil, their exported fields are flattened into the top-level JSON object + // IMPORTANT: Only one of the following can be non-nil at a time, otherwise the JSON marshalling will override the common fields + *ToolMessage + *AssistantMessage +} + +type MessageContent struct { + ContentStr *string + ContentBlocks *[]ContentBlock +} + +// MarshalJSON implements custom JSON marshalling for MessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc MessageContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return sonic.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return sonic.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (mc *MessageContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []ContentBlock + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") +} + +type ContentBlockType string + +const ( + ContentBlockTypeText ContentBlockType = "text" + ContentBlockTypeImage ContentBlockType = "image_url" + ContentBlockTypeInputAudio ContentBlockType = "input_audio" +) + +type ContentBlock struct { + Type ContentBlockType `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *ImageURLStruct `json:"image_url,omitempty"` + InputAudio *InputAudioStruct `json:"input_audio,omitempty"` +} + +// ToolMessage represents a message from a tool +type ToolMessage struct { + ToolCallID *string `json:"tool_call_id,omitempty"` } -// Message represents a single message in a chat conversation. -type Message struct { - Role ModelChatMessageRole `json:"role"` - Content *string `json:"content,omitempty"` - ImageContent *ImageContent `json:"image_content,omitempty"` - ToolCalls *[]Tool `json:"tool_calls,omitempty"` +// AssistantMessage represents a message from an assistant +type AssistantMessage struct { + Refusal *string `json:"refusal,omitempty"` + Annotations []Annotation `json:"annotations,omitempty"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + Thought *string `json:"thought,omitempty"` } // ImageContent represents image data in a message. -type ImageContent struct { - Type *string `json:"type"` - URL string `json:"url"` - MediaType *string `json:"media_type"` - Detail *string `json:"detail"` +type ImageURLStruct struct { + URL string `json:"url"` + Detail *string `json:"detail,omitempty"` +} + +// InputAudioStruct represents audio data in a message. +// Data carries the audio payload as a string (e.g., data URL or provider-accepted encoded content). +// Format is optional (e.g., "wav", "mp3"); when nil, providers may attempt auto-detection. +type InputAudioStruct struct { + Data string `json:"data"` + Format *string `json:"format,omitempty"` } //* Response Structs @@ -158,13 +514,16 @@ type ImageContent struct { // BifrostResponse represents the complete result from any bifrost request. type BifrostResponse struct { ID string `json:"id,omitempty"` - Object string `json:"object,omitempty"` // text.completion or chat.completion + Object string `json:"object,omitempty"` // text.completion, chat.completion, embedding, speech, transcribe Choices []BifrostResponseChoice `json:"choices,omitempty"` + Data []BifrostEmbedding `json:"data,omitempty"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) + Speech *BifrostSpeech `json:"speech,omitempty"` // Maps to "speech" field in provider responses (e.g., OpenAI speech format) + Transcribe *BifrostTranscribe `json:"transcribe,omitempty"` // Maps to "transcribe" field in provider responses (e.g., OpenAI transcription format) Model string `json:"model,omitempty"` Created int `json:"created,omitempty"` // The Unix timestamp (in seconds). ServiceTier *string `json:"service_tier,omitempty"` SystemFingerprint *string `json:"system_fingerprint,omitempty"` - Usage LLMUsage `json:"usage,omitempty"` + Usage *LLMUsage `json:"usage,omitempty"` ExtraFields BifrostResponseExtraFields `json:"extra_fields"` } @@ -177,6 +536,18 @@ type LLMUsage struct { CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` } +type AudioLLMUsage struct { + InputTokens int `json:"input_tokens"` + InputTokensDetails *AudioTokenDetails `json:"input_tokens_details,omitempty"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type AudioTokenDetails struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` +} + // TokenDetails provides detailed information about token usage. // It is not provided by all model providers. type TokenDetails struct { @@ -260,41 +631,224 @@ type Annotation struct { Citation Citation `json:"url_citation"` } -// BifrostResponseChoiceMessage represents a choice in the completion response -type BifrostResponseChoiceMessage struct { - Role ModelChatMessageRole `json:"role"` - Content *string `json:"content,omitempty"` - Refusal *string `json:"refusal,omitempty"` - Annotations []Annotation `json:"annotations,omitempty"` - ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` +type BifrostEmbedding struct { + Index int `json:"index"` + Object string `json:"object"` // embedding + Embedding BifrostEmbeddingResponse `json:"embedding"` // can be []float32 or string +} + +type BifrostEmbeddingResponse struct { + EmbeddingStr *string + EmbeddingArray *[]float32 + Embedding2DArray *[][]float32 +} + +func (be BifrostEmbeddingResponse) MarshalJSON() ([]byte, error) { + if be.EmbeddingStr != nil { + return sonic.Marshal(be.EmbeddingStr) + } + if be.EmbeddingArray != nil { + return sonic.Marshal(be.EmbeddingArray) + } + if be.Embedding2DArray != nil { + return sonic.Marshal(be.Embedding2DArray) + } + return nil, fmt.Errorf("no embedding found") +} + +func (be *BifrostEmbeddingResponse) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + be.EmbeddingStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of float32 + var arrayContent []float32 + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + be.EmbeddingArray = &arrayContent + return nil + } + + // Try to unmarshal as a direct 2D array of float32 + var arrayContent2D [][]float32 + if err := sonic.Unmarshal(data, &arrayContent2D); err == nil { + be.Embedding2DArray = &arrayContent2D + return nil + } + + return fmt.Errorf("embedding field is neither a string nor an array of float32 nor a 2D array of float32") } -// BifrostResponseChoice represents a choice in the completion result +// BifrostResponseChoice represents a choice in the completion result. +// This struct can represent either a streaming or non-streaming response choice. +// IMPORTANT: Only one of BifrostNonStreamResponseChoice or BifrostStreamResponseChoice +// should be non-nil at a time. type BifrostResponseChoice struct { - Index int `json:"index"` - Message BifrostResponseChoiceMessage `json:"message"` - FinishReason *string `json:"finish_reason,omitempty"` - StopString *string `json:"stop,omitempty"` - LogProbs *LogProbs `json:"log_probs,omitempty"` + Index int `json:"index"` + FinishReason *string `json:"finish_reason,omitempty"` + + *BifrostNonStreamResponseChoice + *BifrostStreamResponseChoice +} + +// BifrostNonStreamResponseChoice represents a choice in the non-stream response +type BifrostNonStreamResponseChoice struct { + Message BifrostMessage `json:"message"` + StopString *string `json:"stop,omitempty"` + LogProbs *LogProbs `json:"log_probs,omitempty"` +} + +// BifrostStreamResponseChoice represents a choice in the stream response +type BifrostStreamResponseChoice struct { + Delta BifrostStreamDelta `json:"delta"` // Partial message info +} + +// BifrostStreamDelta represents a delta in the stream response +type BifrostStreamDelta struct { + Role *string `json:"role,omitempty"` // Only in the first chunk + Content *string `json:"content,omitempty"` // May be empty string or null + Thought *string `json:"thought,omitempty"` // May be empty string or null + Refusal *string `json:"refusal,omitempty"` // Refusal content if any + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // If tool calls used (supports incremental updates) +} + +type BifrostSpeech struct { + Usage *AudioLLMUsage `json:"usage,omitempty"` + Audio []byte `json:"audio"` + + *BifrostSpeechStreamResponse +} +type BifrostSpeechStreamResponse struct { + Type string `json:"type"` +} + +// BifrostTranscribe represents transcription response data +type BifrostTranscribe struct { + // Common fields for both streaming and non-streaming + Text string `json:"text"` + LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"` + Usage *TranscriptionUsage `json:"usage,omitempty"` + + // Embedded structs for specific fields only + *BifrostTranscribeNonStreamResponse + *BifrostTranscribeStreamResponse +} + +// BifrostTranscribeNonStreamResponse represents non-streaming specific fields only +type BifrostTranscribeNonStreamResponse struct { + Task *string `json:"task,omitempty"` // e.g., "transcribe" + Language *string `json:"language,omitempty"` // e.g., "english" + Duration *float64 `json:"duration,omitempty"` // Duration in seconds + Words []TranscriptionWord `json:"words,omitempty"` + Segments []TranscriptionSegment `json:"segments,omitempty"` +} + +// BifrostTranscribeStreamResponse represents streaming specific fields only +type BifrostTranscribeStreamResponse struct { + Type *string `json:"type,omitempty"` // "transcript.text.delta" or "transcript.text.done" + Delta *string `json:"delta,omitempty"` // For delta events +} + +// TranscriptionLogProb represents log probability information for transcription +type TranscriptionLogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []int `json:"bytes"` +} + +// TranscriptionWord represents word-level timing information +type TranscriptionWord struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` +} + +// TranscriptionSegment represents segment-level transcription information +type TranscriptionSegment struct { + ID int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogProb float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + +// TranscriptionUsage represents usage information for transcription +type TranscriptionUsage struct { + Type string `json:"type"` // "tokens" or "duration" + InputTokens *int `json:"input_tokens,omitempty"` + InputTokenDetails *AudioTokenDetails `json:"input_token_details,omitempty"` + OutputTokens *int `json:"output_tokens,omitempty"` + TotalTokens *int `json:"total_tokens,omitempty"` + Seconds *int `json:"seconds,omitempty"` // For duration-based usage } // BifrostResponseExtraFields contains additional fields in a response. type BifrostResponseExtraFields struct { - Provider ModelProvider `json:"provider"` - Params ModelParameters `json:"model_params"` - Latency *float64 `json:"latency,omitempty"` - ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history,omitempty"` - BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` - RawResponse interface{} `json:"raw_response"` + Provider ModelProvider `json:"provider"` + Params ModelParameters `json:"model_params"` + Latency *float64 `json:"latency,omitempty"` + ChatHistory *[]BifrostMessage `json:"chat_history,omitempty"` + BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` +} + +// BifrostCacheDebug represents debug information about the cache. +type BifrostCacheDebug struct { + CacheHit bool `json:"cache_hit"` + + CacheID *string `json:"cache_id,omitempty"` + HitType *string `json:"hit_type,omitempty"` + + // Semantic cache only (provider, model, and input tokens will be present for semantic cache, even if cache is not hit) + ProviderUsed *string `json:"provider_used,omitempty"` + ModelUsed *string `json:"model_used,omitempty"` + InputTokens *int `json:"input_tokens,omitempty"` + + // Semantic cache only (only when cache is hit) + Threshold *float64 `json:"threshold,omitempty"` + Similarity *float64 `json:"similarity,omitempty"` +} + +const ( + RequestCancelled = "request_cancelled" +) + +// BifrostStream represents a stream of responses from the Bifrost system. +// Either BifrostResponse or BifrostError will be non-nil. +type BifrostStream struct { + *BifrostResponse + *BifrostError } // BifrostError represents an error from the Bifrost system. +// +// PLUGIN DEVELOPERS: When creating BifrostError in PreHook or PostHook, you can set AllowFallbacks: +// - AllowFallbacks = &true: Bifrost will try fallback providers if available +// - AllowFallbacks = &false: Bifrost will return this error immediately, no fallbacks +// - AllowFallbacks = nil: Treated as true by default (fallbacks allowed for resilience) type BifrostError struct { - EventID *string `json:"event_id,omitempty"` - Type *string `json:"type,omitempty"` - IsBifrostError bool `json:"is_bifrost_error"` - StatusCode *int `json:"status_code,omitempty"` - Error ErrorField `json:"error"` + Provider ModelProvider `json:"-"` + EventID *string `json:"event_id,omitempty"` + Type *string `json:"type,omitempty"` + IsBifrostError bool `json:"is_bifrost_error"` + StatusCode *int `json:"status_code,omitempty"` + Error ErrorField `json:"error"` + AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) + StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior +} + +type StreamControl struct { + LogError *bool `json:"log_error,omitempty"` // Optional: Controls logging of error + SkipStream *bool `json:"skip_stream,omitempty"` // Optional: Controls skipping of stream chunk } // ErrorField represents detailed error information. diff --git a/core/schemas/logger.go b/core/schemas/logger.go index 9e636579f..268244d79 100644 --- a/core/schemas/logger.go +++ b/core/schemas/logger.go @@ -2,9 +2,10 @@ package schemas // LogLevel represents the severity level of a log message. -// It is used to categorize and filter log messages based on their importance. +// Internally it maps to zerolog.Level for interoperability. type LogLevel string +// LogLevel constants for different severity levels. const ( LogLevelDebug LogLevel = "debug" LogLevelInfo LogLevel = "info" @@ -12,6 +13,15 @@ const ( LogLevelError LogLevel = "error" ) +// LoggerOutputType represents the output type of a logger. +type LoggerOutputType string + +// LoggerOutputType constants for different output types. +const ( + LoggerOutputTypeJSON LoggerOutputType = "json" + LoggerOutputTypePretty LoggerOutputType = "pretty" +) + // Logger defines the interface for logging operations in the Bifrost system. // Implementations of this interface should provide methods for logging messages // at different severity levels. @@ -19,17 +29,27 @@ type Logger interface { // Debug logs a debug-level message. // This is used for detailed debugging information that is typically only needed // during development or troubleshooting. - Debug(msg string) + Debug(msg string, args ...any) // Info logs an info-level message. // This is used for general informational messages about normal operation. - Info(msg string) + Info(msg string, args ...any) // Warn logs a warning-level message. // This is used for potentially harmful situations that don't prevent normal operation. - Warn(msg string) + Warn(msg string, args ...any) // Error logs an error-level message. // This is used for serious problems that need attention and may prevent normal operation. - Error(err error) + Error(msg string, args ...any) + + // Fatal logs a fatal-level message. + // This is used for critical situations that require immediate attention and will terminate the program. + Fatal(msg string, args ...any) + + // SetLevel sets the log level for the logger. + SetLevel(level LogLevel) + + // SetOutputType sets the output type for the logger. + SetOutputType(outputType LoggerOutputType) } diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go new file mode 100644 index 000000000..2f2ac61b6 --- /dev/null +++ b/core/schemas/mcp.go @@ -0,0 +1,59 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +// MCPServerInstance represents an MCP server instance for InProcess connections. +// This should be a *github.com/mark3labs/mcp-go/server.MCPServer instance. +// We use interface{} to avoid creating a dependency on the mcp-go package in schemas. +type MCPServerInstance interface{} + +// MCPConfig represents the configuration for MCP integration in Bifrost. +// It enables tool auto-discovery and execution from local and external MCP servers. +type MCPConfig struct { + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations +} + +// MCPClientConfig defines tool filtering for an MCP client. +type MCPClientConfig struct { + Name string `json:"name"` // Client name + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) + ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + InProcessServer MCPServerInstance `json:"-"` // MCP server instance for in-process connections (Go package only) + ToolsToSkip []string `json:"tools_to_skip,omitempty"` // Tools to exclude from this client + ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Tools to include from this client (if specified, only these are used) +} + +// MCPConnectionType defines the communication protocol for MCP connections +type MCPConnectionType string + +const ( + MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based connection + MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based connection + MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events connection + MCPConnectionTypeInProcess MCPConnectionType = "inprocess" // In-process (in-memory) connection +) + +// MCPStdioConfig defines how to launch a STDIO-based MCP server. +type MCPStdioConfig struct { + Command string `json:"command"` // Executable command to run + Args []string `json:"args"` // Command line arguments + Envs []string `json:"envs"` // Environment variables required +} + +type MCPConnectionState string + +const ( + MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use + MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected + MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used +) + +// MCPClient represents a connected MCP client with its configuration and tools, +// and connection information, after it has been initialized. +// It is returned by GetMCPClients() method. +type MCPClient struct { + Name string `json:"name"` // Unique name for this client + Config MCPClientConfig `json:"config"` // Tool filtering settings + Tools []string `json:"tools"` // Available tools mapped by name + State MCPConnectionState `json:"state"` // Connection state +} diff --git a/core/schemas/meta/azure.go b/core/schemas/meta/azure.go deleted file mode 100644 index df5fd163b..000000000 --- a/core/schemas/meta/azure.go +++ /dev/null @@ -1,56 +0,0 @@ -// Package meta provides provider-specific configuration structures and schemas. -// This file contains the Azure-specific configuration implementation. - -package meta - -// AzureMetaConfig represents the Azure-specific configuration. -// It contains Azure-specific settings required for service access and deployment management. -type AzureMetaConfig struct { - Endpoint string `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-02-01" -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetSecretAccessKey() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetRegion() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetSessionToken() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetARN() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetInferenceProfiles() map[string]string { - return nil -} - -// GetEndpoint returns the Azure service endpoint. -// This specifies the base URL for Azure API requests. -func (c *AzureMetaConfig) GetEndpoint() *string { - return &c.Endpoint -} - -// GetDeployments returns the deployment configurations. -// This maps model names to their corresponding Azure deployment names. -// Eg. "gpt-4o": "your-deployment-name-for-gpt-4o" -func (c *AzureMetaConfig) GetDeployments() map[string]string { - return c.Deployments -} - -// GetAPIVersion returns the Azure API version. -// This specifies which version of the Azure API to use. -func (c *AzureMetaConfig) GetAPIVersion() *string { - return c.APIVersion -} diff --git a/core/schemas/meta/bedrock.go b/core/schemas/meta/bedrock.go deleted file mode 100644 index 1a875d3f6..000000000 --- a/core/schemas/meta/bedrock.go +++ /dev/null @@ -1,59 +0,0 @@ -// Package meta provides provider-specific configuration structures and schemas. -// This file contains the AWS Bedrock-specific configuration implementation. - -package meta - -// BedrockMetaConfig represents the AWS Bedrock-specific configuration. -// It contains AWS-specific settings required for authentication and service access. -type BedrockMetaConfig struct { - SecretAccessKey string `json:"secret_access_key,omitempty"` // AWS secret access key for authentication - Region *string `json:"region,omitempty"` // AWS region for service access - SessionToken *string `json:"session_token,omitempty"` // AWS session token for temporary credentials - ARN *string `json:"arn,omitempty"` // Amazon Resource Name for resource identification - InferenceProfiles map[string]string `json:"inference_profiles,omitempty"` // Mapping of model identifiers to inference profiles -} - -// GetSecretAccessKey returns the AWS secret access key. -// This is used for AWS API authentication. -func (c *BedrockMetaConfig) GetSecretAccessKey() *string { - return &c.SecretAccessKey -} - -// GetRegion returns the AWS region. -// This specifies which AWS region the service should be accessed from. -func (c *BedrockMetaConfig) GetRegion() *string { - return c.Region -} - -// GetSessionToken returns the AWS session token. -// This is used for temporary credentials in AWS authentication. -func (c *BedrockMetaConfig) GetSessionToken() *string { - return c.SessionToken -} - -// GetARN returns the Amazon Resource Name. -// This uniquely identifies AWS resources. -func (c *BedrockMetaConfig) GetARN() *string { - return c.ARN -} - -// GetInferenceProfiles returns the inference profiles mapping. -// This maps model identifiers to their corresponding inference profiles. -func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string { - return c.InferenceProfiles -} - -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetEndpoint() *string { - return nil -} - -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetDeployments() map[string]string { - return nil -} - -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetAPIVersion() *string { - return nil -} diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index c10adebf3..93f31917b 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -3,28 +3,69 @@ package schemas import "context" +// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. +// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit). +type PluginShortCircuit struct { + Response *BifrostResponse // If set, short-circuit with this response (skips provider call) + Stream chan *BifrostStream // If set, short-circuit with this stream (skips provider call) + Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field) +} + // Plugin defines the interface for Bifrost plugins. // Plugins can intercept and modify requests and responses at different stages // of the processing pipeline. // User can provide multiple plugins in the BifrostConfig. // PreHooks are executed in the order they are registered. // PostHooks are executed in the reverse order of PreHooks. - +// // PreHooks and PostHooks can be used to implement custom logic, such as: // - Rate limiting // - Caching // - Logging // - Monitoring +// +// Plugin error handling: +// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance. +// - PreHook and PostHook can both modify the request/response and the error. Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error). +// - PostHook is always called with both the current response and error, and should handle either being nil. +// - Only truly empty errors (no message, no error, no status code, no type) are treated as recoveries by the pipeline. +// - If a PreHook returns a PluginShortCircuit, the provider call may be skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order. +// - The plugin pipeline ensures symmetry: for every PreHook executed, the corresponding PostHook will be called in reverse order. +// +// IMPORTANT: When returning BifrostError from PreHook or PostHook: +// - You can set the AllowFallbacks field to control fallback behavior +// - AllowFallbacks = &true: Allow Bifrost to try fallback providers +// - AllowFallbacks = &false: Do not try fallbacks, return error immediately +// - AllowFallbacks = nil: Treated as true by default (allow fallbacks for resilience) +// +// Plugin authors should ensure their hooks are robust to both response and error being nil, and should not assume either is always present. type Plugin interface { + // GetName returns the name of the plugin. + GetName() string + // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. // The context parameter can be used to maintain state across plugin calls. - // Returns the modified request and any error that occurred during processing. - PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) + // Returns the modified request, an optional short-circuit decision, and any error that occurred during processing. + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) + + // PostHook is called after a response is received from a provider or a PreHook short-circuit. + // It allows plugins to modify the response and/or error before it is returned to the caller. + // Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error). + // Returns the modified response, bifrost error, and any error that occurred during processing. + PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) + + // Cleanup is called on bifrost shutdown. + // It allows plugins to clean up any resources they have allocated. + // Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance. + Cleanup() error +} - // PostHook is called after a response is received from a provider. - // It allows plugins to modify the response before it is returned to the caller. - // Returns the modified response and any error that occurred during processing. - PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) +// PluginConfig is the configuration for a plugin. +// It contains the name of the plugin, whether it is enabled, and the configuration for the plugin. +type PluginConfig struct { + Enabled bool `json:"enabled"` + Name string `json:"name"` + Config any `json:"config,omitempty"` } diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 56376b730..76d88ad85 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -1,15 +1,20 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "time" +import ( + "context" + "maps" + "time" +) const ( DefaultMaxRetries = 0 DefaultRetryBackoffInitial = 500 * time.Millisecond DefaultRetryBackoffMax = 5 * time.Second DefaultRequestTimeoutInSeconds = 30 - DefaultBufferSize = 100 - DefaultConcurrency = 10 + DefaultBufferSize = 5000 + DefaultConcurrency = 1000 + DefaultStreamBufferSize = 5000 ) // Pre-defined errors for provider operations @@ -23,32 +28,23 @@ const ( ) // NetworkConfig represents the network configuration for provider connections. +// ExtraHeaders is automatically copied during provider initialization to prevent data races. type NetworkConfig struct { - DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration + // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration } -// MetaConfig defines the interface for provider-specific configuration. -// Check /meta folder for implemented provider-specific meta configurations. -type MetaConfig interface { - // GetSecretAccessKey returns the secret access key for authentication - GetSecretAccessKey() *string - // GetRegion returns the region for the provider - GetRegion() *string - // GetSessionToken returns the session token for authentication - GetSessionToken() *string - // GetARN returns the Amazon Resource Name (ARN) - GetARN() *string - // GetInferenceProfiles returns the inference profiles - GetInferenceProfiles() map[string]string - // GetEndpoint returns the provider endpoint - GetEndpoint() *string - // GetDeployments returns the deployment configurations - GetDeployments() map[string]string - // GetAPIVersion returns the API version - GetAPIVersion() *string +// DefaultNetworkConfig is the default network configuration for provider connections. +var DefaultNetworkConfig = NetworkConfig{ + DefaultRequestTimeoutInSeconds: DefaultRequestTimeoutInSeconds, + MaxRetries: DefaultMaxRetries, + RetryBackoffInitial: DefaultRetryBackoffInitial, + RetryBackoffMax: DefaultRetryBackoffMax, } // ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes. @@ -57,6 +53,12 @@ type ConcurrencyAndBufferSize struct { BufferSize int `json:"buffer_size"` // Size of the buffer } +// DefaultConcurrencyAndBufferSize is the default concurrency and buffer size for provider operations. +var DefaultConcurrencyAndBufferSize = ConcurrencyAndBufferSize{ + Concurrency: DefaultConcurrency, + BufferSize: DefaultBufferSize, +} + // ProxyType defines the type of proxy to use for connections. type ProxyType string @@ -79,24 +81,141 @@ type ProxyConfig struct { Password string `json:"password"` // Password for proxy authentication } +// AllowedRequests controls which operations are permitted. +// A nil *AllowedRequests means "all operations allowed." +// A non-nil value only allows fields set to true; omitted or false fields are disallowed. +type AllowedRequests struct { + TextCompletion bool `json:"text_completion"` + ChatCompletion bool `json:"chat_completion"` + ChatCompletionStream bool `json:"chat_completion_stream"` + Embedding bool `json:"embedding"` + Speech bool `json:"speech"` + SpeechStream bool `json:"speech_stream"` + Transcription bool `json:"transcription"` + TranscriptionStream bool `json:"transcription_stream"` +} + +// IsOperationAllowed checks if a specific operation is allowed +func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool { + if ar == nil { + return true // Default to allowed if no restrictions + } + + switch operation { + case OperationTextCompletion: + return ar.TextCompletion + case OperationChatCompletion: + return ar.ChatCompletion + case OperationChatCompletionStream: + return ar.ChatCompletionStream + case OperationEmbedding: + return ar.Embedding + case OperationSpeech: + return ar.Speech + case OperationSpeechStream: + return ar.SpeechStream + case OperationTranscription: + return ar.Transcription + case OperationTranscriptionStream: + return ar.TranscriptionStream + default: + return false // Default to not allowed for unknown operations + } +} + +type CustomProviderConfig struct { + CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost + BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type + AllowedRequests *AllowedRequests `json:"allowed_requests,omitempty"` +} + +// IsOperationAllowed checks if a specific operation is allowed for this custom provider +func (cpc *CustomProviderConfig) IsOperationAllowed(operation Operation) bool { + if cpc == nil || cpc.AllowedRequests == nil { + return true // Default to allowed if no restrictions + } + return cpc.AllowedRequests.IsOperationAllowed(operation) +} + // ProviderConfig represents the complete configuration for a provider. -// An array of ProviderConfig needs to provided in GetConfigForProvider +// An array of ProviderConfig needs to be provided in GetConfigForProvider // in your account interface implementation. type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` // Network configuration - MetaConfig MetaConfig `json:"meta_config,omitempty"` // Provider-specific configuration ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings // Logger instance, can be provided by the user or bifrost default logger is used if not provided - Logger Logger `json:"logger"` - ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + Logger Logger `json:"-"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) + CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` } +type Operation string + +const ( + OperationTextCompletion Operation = "text_completion" + OperationChatCompletion Operation = "chat_completion" + OperationChatCompletionStream Operation = "chat_completion_stream" + OperationEmbedding Operation = "embedding" + OperationSpeech Operation = "speech" + OperationSpeechStream Operation = "speech_stream" + OperationTranscription Operation = "transcription" + OperationTranscriptionStream Operation = "transcription_stream" +) + +func (config *ProviderConfig) CheckAndSetDefaults() { + if config.ConcurrencyAndBufferSize.Concurrency == 0 { + config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency + } + + if config.ConcurrencyAndBufferSize.BufferSize == 0 { + config.ConcurrencyAndBufferSize.BufferSize = DefaultBufferSize + } + + if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { + config.NetworkConfig.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds + } + + if config.NetworkConfig.MaxRetries == 0 { + config.NetworkConfig.MaxRetries = DefaultMaxRetries + } + + if config.NetworkConfig.RetryBackoffInitial == 0 { + config.NetworkConfig.RetryBackoffInitial = DefaultRetryBackoffInitial + } + + if config.NetworkConfig.RetryBackoffMax == 0 { + config.NetworkConfig.RetryBackoffMax = DefaultRetryBackoffMax + } + + // Create a defensive copy of ExtraHeaders to prevent data races + if config.NetworkConfig.ExtraHeaders != nil { + headersCopy := make(map[string]string, len(config.NetworkConfig.ExtraHeaders)) + maps.Copy(headersCopy, config.NetworkConfig.ExtraHeaders) + config.NetworkConfig.ExtraHeaders = headersCopy + } +} + +type PostHookRunner func(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError) + // Provider defines the interface for AI model providers. type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // TextCompletion performs a text completion request - TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + TextCompletion(ctx context.Context, model string, key Key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, model string, key Key, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) + // ChatCompletionStream performs a chat completion stream request + ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, model string, key Key, messages []BifrostMessage, params *ModelParameters) (chan *BifrostStream, *BifrostError) + // Embedding performs an embedding request + Embedding(ctx context.Context, model string, key Key, input *EmbeddingInput, params *ModelParameters) (*BifrostResponse, *BifrostError) + // Speech performs a text to speech request + Speech(ctx context.Context, model string, key Key, input *SpeechInput, params *ModelParameters) (*BifrostResponse, *BifrostError) + // SpeechStream performs a text to speech stream request + SpeechStream(ctx context.Context, postHookRunner PostHookRunner, model string, key Key, input *SpeechInput, params *ModelParameters) (chan *BifrostStream, *BifrostError) + // Transcription performs a transcription request + Transcription(ctx context.Context, model string, key Key, input *TranscriptionInput, params *ModelParameters) (*BifrostResponse, *BifrostError) + // TranscriptionStream performs a transcription stream request + TranscriptionStream(ctx context.Context, postHookRunner PostHookRunner, model string, key Key, input *TranscriptionInput, params *ModelParameters) (chan *BifrostStream, *BifrostError) } diff --git a/core/tests/account.go b/core/tests/account.go deleted file mode 100644 index 53278b7ce..000000000 --- a/core/tests/account.go +++ /dev/null @@ -1,202 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "fmt" - "os" - "time" - - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/core/schemas/meta" -) - -// BaseAccount provides a test implementation of the Account interface. -// It implements basic account functionality for testing purposes, supporting -// multiple AI providers including OpenAI, Anthropic, Bedrock, Cohere, and Azure. -// The implementation uses environment variables from the .env file for API keys and provides -// default configurations suitable for testing. -type BaseAccount struct{} - -// GetConfiguredProviders returns the list of initially supported providers. -// This implementation returns OpenAI, Anthropic, and Bedrock as the default providers. -// -// Returns: -// - []schemas.SupportedModelProvider: A slice containing the supported provider identifiers -// - error: Always returns nil as this implementation doesn't produce errors -func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Bedrock, schemas.Cohere, schemas.Azure}, nil -} - -// GetKeysForProvider returns the API keys and associated models for a given provider. -// It retrieves API keys from environment variables and maps them to their supported models. -// Each key includes a weight value for load balancing purposes. -// -// Parameters: -// - providerKey: The identifier of the provider to get keys for -// -// Returns: -// - []schemas.Key: A slice of Key objects containing API keys and their configurations -// - error: An error if the provider is not supported -// -// Environment Variables Used: -// - OPENAI_API_KEY: API key for OpenAI -// - ANTHROPIC_API_KEY: API key for Anthropic -// - BEDROCK_API_KEY: API key for AWS Bedrock -// - COHERE_API_KEY: API key for Cohere -// - AZURE_API_KEY: API key for Azure OpenAI -func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - switch providerKey { - case schemas.OpenAI: - return []schemas.Key{ - { - Value: os.Getenv("OPENAI_API_KEY"), - Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, - Weight: 1.0, - }, - }, nil - case schemas.Anthropic: - return []schemas.Key{ - { - Value: os.Getenv("ANTHROPIC_API_KEY"), - Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, - Weight: 1.0, - }, - }, nil - case schemas.Bedrock: - return []schemas.Key{ - { - Value: os.Getenv("BEDROCK_API_KEY"), - Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, - Weight: 1.0, - }, - }, nil - case schemas.Cohere: - return []schemas.Key{ - { - Value: os.Getenv("COHERE_API_KEY"), - Models: []string{"command-a-03-2025"}, - Weight: 1.0, - }, - }, nil - case schemas.Azure: - return []schemas.Key{ - { - Value: os.Getenv("AZURE_API_KEY"), - Models: []string{"gpt-4o"}, - Weight: 1.0, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } -} - -// GetConfigForProvider returns the configuration settings for a given provider. -// It provides standardized configuration settings for network operations, -// concurrency, and provider-specific metadata. -// -// Parameters: -// - providerKey: The identifier of the provider to get configuration for -// -// Returns: -// - *schemas.ProviderConfig: Configuration settings for the provider, including: -// - Network settings (timeouts, retries, backoff) -// - Concurrency and buffer size settings -// - Provider-specific metadata (for Bedrock and Azure) -// - error: An error if the provider is not supported -// -// Environment Variables Used: -// - BEDROCK_ACCESS_KEY: AWS access key for Bedrock configuration -// - AZURE_ENDPOINT: Azure endpoint for Azure OpenAI configuration -// -// Default Settings: -// - Request Timeout: 30 seconds -// - Max Retries: 1 -// - Initial Backoff: 100ms -// - Max Backoff: 2s -// - Concurrency: 3 -// - Buffer Size: 10 -func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - switch providerKey { - case schemas.OpenAI: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Anthropic: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Bedrock: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.BedrockMetaConfig{ - SecretAccessKey: os.Getenv("BEDROCK_ACCESS_KEY"), - Region: StrPtr("us-east-1"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Cohere: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Azure: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.AzureMetaConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-aug", - }, - APIVersion: StrPtr("2024-08-01-preview"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } -} diff --git a/core/tests/anthropic_test.go b/core/tests/anthropic_test.go deleted file mode 100644 index 5df5170b4..000000000 --- a/core/tests/anthropic_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestAnthropic(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - maxTokens := 4096 - - config := TestConfig{ - Provider: schemas.Anthropic, - TextModel: "claude-2.1", - ChatModel: "claude-3-5-sonnet-20240620", - SetupText: true, - SetupToolCalls: false, // available in 3.7 sonnet - SetupImage: true, - SetupBaseImage: true, - CustomParams: &schemas.ModelParameters{ - MaxTokens: &maxTokens, - }, - } - - SetupAllRequests(bifrost, config) - - bifrost.Cleanup() -} diff --git a/core/tests/azure_test.go b/core/tests/azure_test.go deleted file mode 100644 index 81f37b630..000000000 --- a/core/tests/azure_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestAzure(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.Azure, - ChatModel: "gpt-4o", - SetupText: false, // gpt-4o does not support text completion - SetupToolCalls: true, - SetupImage: true, - SetupBaseImage: false, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/bedrock_test.go b/core/tests/bedrock_test.go deleted file mode 100644 index ed227629c..000000000 --- a/core/tests/bedrock_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestBedrock(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - maxTokens := 4096 - textCompletion := "\n\nHuman:\n\nAssistant:" - - config := TestConfig{ - Provider: schemas.Bedrock, - TextModel: "anthropic.claude-v2:1", - ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", - SetupText: true, - SetupToolCalls: true, - SetupImage: true, - SetupBaseImage: false, - CustomParams: &schemas.ModelParameters{ - MaxTokens: &maxTokens, - }, - CustomTextCompletion: &textCompletion, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/cohere_test.go b/core/tests/cohere_test.go deleted file mode 100644 index 37a7bfb37..000000000 --- a/core/tests/cohere_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestCohere(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.Cohere, - ChatModel: "command-a-03-2025", - SetupText: false, // Cohere does not support text completion - SetupToolCalls: true, - SetupImage: false, - SetupBaseImage: false, - } - - SetupAllRequests(bifrost, config) - - bifrost.Cleanup() -} diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go deleted file mode 100644 index cae22a1b7..000000000 --- a/core/tests/openai_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestOpenAI(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.OpenAI, - TextModel: "gpt-4o-mini", - ChatModel: "gpt-4o-mini", - SetupText: true, // OpenAI does not support text completion - SetupToolCalls: false, - SetupImage: false, - SetupBaseImage: false, - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20240620", - }, - }, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/setup.go b/core/tests/setup.go deleted file mode 100644 index 9af47664f..000000000 --- a/core/tests/setup.go +++ /dev/null @@ -1,106 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "fmt" - "log" - "os" - - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/plugins" - - "github.com/joho/godotenv" -) - -// loadEnv loads environment variables from a .env file into the process environment. -// It uses the godotenv package to load variables and fails if the .env file cannot be loaded. -// -// Environment Variables: -// - .env file: Contains configuration values for the test environment -// -// Returns: -// - None, but will log.Fatal if the .env file cannot be loaded -func loadEnv() { - err := godotenv.Load() - if err != nil { - log.Fatal("Error loading .env file:", err) - } -} - -// getPlugin initializes and returns a Plugin instance for testing purposes. -// It sets up the Maxim logger with configuration from environment variables. -// -// Environment Variables: -// - MAXIM_API_KEY: API key for Maxim SDK authentication -// - MAXIM_LOGGER_ID: ID for the Maxim logger instance -// -// Returns: -// - schemas.Plugin: A configured plugin instance for request/response tracing -// - error: Any error that occurred during plugin initialization -func getPlugin() (schemas.Plugin, error) { - loadEnv() - - // check if Maxim Logger variables are set - if os.Getenv("MAXIM_API_KEY") == "" { - return nil, fmt.Errorf("MAXIM_API_KEY is not set, please set it in your .env file or pass nil in the Plugins field when initializing Bifrost") - } - - if os.Getenv("MAXIM_LOGGER_ID") == "" { - return nil, fmt.Errorf("MAXIM_LOGGER_ID is not set, please set it in your .env file or pass nil in the Plugins field when initializing Bifrost") - } - - plugin, err := plugins.NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) - if err != nil { - return nil, err - } - - return plugin, nil -} - -// getBifrost initializes and returns a Bifrost instance for testing. -// It sets up the test account, plugin, and logger configuration. -// -// Environment Variables: -// - Uses environment variables loaded by loadEnv() -// -// Returns: -// - *bifrost.Bifrost: A configured Bifrost instance ready for testing -// - error: Any error that occurred during Bifrost initialization -// -// The function: -// 1. Loads environment variables -// 2. Creates a test account instance -// 3. Initializes a plugin for request tracing -// 4. Configures Bifrost with the account, plugin, and default logger -func getBifrost() (*bifrost.Bifrost, error) { - loadEnv() - - account := BaseAccount{} - - // You can pass nil in the Plugins field if you don't want to use the implemented example plugin. - plugin, err := getPlugin() - if err != nil { - fmt.Println("Error setting up the plugin:", err) - return nil, err - } - - // Initialize Bifrost - b, err := bifrost.Init(schemas.BifrostConfig{ - Account: &account, - // Plugins: nil, - Plugins: []schemas.Plugin{plugin}, - Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), - }) - if err != nil { - return nil, err - } - - return b, nil -} - -func StrPtr(s string) *string { - return &s -} diff --git a/core/tests/tests.go b/core/tests/tests.go deleted file mode 100644 index f8a70b8d5..000000000 --- a/core/tests/tests.go +++ /dev/null @@ -1,314 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "context" - "fmt" - "time" - - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" -) - -// TestConfig holds configuration for test requests across different AI providers. -// It provides a flexible way to configure test scenarios for various provider capabilities. -// -// Fields: -// - Provider: The AI provider to test (e.g., OpenAI, Anthropic, etc.) -// - ChatModel: The model to use for chat completion tests -// - TextModel: The model to use for text completion tests -// - Messages: Custom messages to use in chat tests (optional) -// - SetupText: Whether to run text completion tests -// - SetupToolCalls: Whether to run function calling tests -// - SetupImage: Whether to run image input tests -// - SetupBaseImage: Whether to run base64 image tests -// - CustomTextCompletion: Custom text for completion tests (optional) -// - CustomParams: Custom model parameters for requests (optional) -// - Fallbacks: List of fallback providers and models to try if primary provider fails -type TestConfig struct { - Provider schemas.ModelProvider - ChatModel string - TextModel string - Messages []string - SetupText bool - SetupToolCalls bool - SetupImage bool - SetupBaseImage bool - CustomTextCompletion *string - CustomParams *schemas.ModelParameters - Fallbacks []schemas.Fallback -} - -// CommonTestMessages contains default messages used across providers for testing. -// These messages are used when no custom messages are provided in the test configuration. -var CommonTestMessages = []string{ - "Hello! How are you today?", - "Tell me a joke!", - "What's your favorite programming language?", -} - -// WeatherToolParams defines the parameters for a weather function tool. -// This is used to test function calling capabilities of AI providers. -var WeatherToolParams = schemas.ModelParameters{ - Tools: &[]schemas.Tool{{ - Type: "function", - Function: schemas.Function{ - Name: "get_weather", - Description: "Get the current weather in a given location", - Parameters: schemas.FunctionParameters{ - Type: "object", - Properties: map[string]interface{}{ - "location": map[string]interface{}{ - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": map[string]interface{}{ - "type": "string", - "enum": []string{"celsius", "fahrenheit"}, - }, - }, - Required: []string{"location"}, - }, - }, - }}, -} - -// setupTextCompletionRequest sets up and executes a text completion test request. -// It runs asynchronously and prints the result or error to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the request -// - config: Test configuration containing model and parameters -// - ctx: Context for the request -func setupTextCompletionRequest(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - text := "Hello world!" - if config.CustomTextCompletion != nil { - text = *config.CustomTextCompletion - } - - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - go func() { - result, err := bifrost.TextCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.TextModel, - Input: schemas.RequestInput{ - TextCompletionInput: &text, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s text completion: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Text Completion Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() -} - -// setupChatCompletionRequests sets up and executes multiple chat completion test requests. -// It runs requests asynchronously with staggered delays and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupChatCompletionRequests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - messages := config.Messages - if len(messages) == 0 { - messages = CommonTestMessages - } - - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - for i, message := range messages { - delay := time.Duration(100*(i+1)) * time.Millisecond - go func(msg string, delay time.Duration, index int) { - time.Sleep(delay) - messages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: &msg, - }, - } - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s request %d: %v\n", config.Provider, index+1, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Chat Completion Result %d: %s\n", config.Provider, index+1, *result.Choices[0].Message.Content) - } - }(message, delay, i) - } -} - -// setupImageTests sets up and executes image input test requests. -// It tests both URL and base64 image inputs (if enabled) and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - // URL image test - urlImageMessages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: StrPtr("What is Happening in this picture?"), - ImageContent: &schemas.ImageContent{ - Type: StrPtr("url"), - URL: "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg", - }, - }, - } - - if config.Provider == schemas.Anthropic { - urlImageMessages[0].ImageContent.Type = StrPtr("url") - } - - go func() { - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &urlImageMessages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s URL image request: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s URL Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() - - // Base64 image test (only for providers that support it) - if config.SetupBaseImage { - base64ImageMessages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: StrPtr("What is this image about?"), - ImageContent: &schemas.ImageContent{ - Type: StrPtr("base64"), - URL: "/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=", - MediaType: StrPtr("image/jpeg"), - }, - }, - } - - go func() { - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &base64ImageMessages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s base64 image request: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Base64 Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() - } -} - -// setupToolCalls sets up and executes function calling test requests. -// It tests the provider's ability to handle tool/function calls and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupToolCalls(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - messages := []string{"What's the weather like in Mumbai?"} - - params := WeatherToolParams - if config.CustomParams != nil { - customParams := *config.CustomParams - if customParams.Tools != nil { - params.Tools = customParams.Tools - } - if customParams.MaxTokens != nil { - params.MaxTokens = customParams.MaxTokens - } - } - - for i, message := range messages { - delay := time.Duration(100*(i+1)) * time.Millisecond - go func(msg string, delay time.Duration, index int) { - time.Sleep(delay) - messages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: &msg, - }, - } - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s tool call request %d: %v\n", config.Provider, index+1, err.Error.Message) - } else { - if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { - toolCall := *result.Choices[0].Message.ToolCalls - fmt.Printf("\nπŸ’ %s Tool Call Result %d: %s\n", config.Provider, index+1, toolCall[0].Function.Arguments) - } else { - fmt.Printf("\nπŸ’ %s No tool calls in response %d\n", config.Provider, index+1) - if result.ExtraFields.RawResponse != nil { - fmt.Println("\nRaw JSON Response", result.ExtraFields.RawResponse) - } - } - } - }(message, delay, i) - } -} - -// SetupAllRequests sets up and executes all configured test requests for a provider. -// It coordinates the execution of text completion, chat completion, image, and tool call tests -// based on the provided configuration. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration specifying which tests to run -func SetupAllRequests(bifrost *bifrost.Bifrost, config TestConfig) { - ctx := context.Background() - - if config.SetupText { - setupTextCompletionRequest(bifrost, config, ctx) - } - - setupChatCompletionRequests(bifrost, config, ctx) - - if config.SetupImage { - setupImageTests(bifrost, config, ctx) - } - - if config.SetupToolCalls { - setupToolCalls(bifrost, config, ctx) - } -} diff --git a/core/utils.go b/core/utils.go new file mode 100644 index 000000000..e0bddf0aa --- /dev/null +++ b/core/utils.go @@ -0,0 +1,150 @@ +package bifrost + +import ( + "context" + "math/rand" + "time" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// Ptr returns a pointer to the given value. +func Ptr[T any](v T) *T { + return &v +} + +func attachContextKeys(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) context.Context { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, requestType) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestProvider, req.Provider) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestModel, req.Model) + + return ctx +} + +// providerRequiresKey returns true if the given provider requires an API key for authentication. +// Some providers like Ollama and SGL are keyless and don't require API keys. +func providerRequiresKey(providerKey schemas.ModelProvider) bool { + return providerKey != schemas.Ollama && providerKey != schemas.SGL +} + +// canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. +// Some providers like Vertex and Bedrock have their credentials in additional key configs.. +func canProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool { + return providerKey == schemas.Vertex || providerKey == schemas.Bedrock +} + +// calculateBackoff implements exponential backoff with jitter for retry attempts. +func calculateBackoff(attempt int, config *schemas.ProviderConfig) time.Duration { + // Calculate an exponential backoff: initial * 2^attempt + backoff := min(config.NetworkConfig.RetryBackoffInitial*time.Duration(1<\",\"type\":\"audio.chunk\"}\n\ndata: [DONE]\n\n" + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/v1/audio/transcriptions": { + "post": { + "summary": "Create Transcription (Speech-to-Text)", + "description": "Transcribes audio files to text using AI speech recognition. Supports multiple audio formats, languages, and detailed timing information.", + "operationId": "createTranscription", + "tags": ["Audio"], + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/TranscriptionRequest" + }, + "examples": { + "basic_transcription": { + "summary": "Basic audio transcription", + "value": { + "model": "openai/whisper-1", + "file": "", + "language": "en", + "response_format": "json" + } + }, + "detailed_transcription": { + "summary": "Detailed transcription with timing", + "value": { + "model": "openai/whisper-1", + "file": "", + "language": "en", + "response_format": "verbose_json", + "prompt": "This is a recording of a technical presentation about AI.", + "temperature": 0.0 + } + }, + "streaming_transcription": { + "summary": "Streaming transcription", + "value": { + "model": "openai/whisper-1", + "file": "", + "stream": "true", + "response_format": "json" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Transcription result or streaming response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + }, + "examples": { + "transcription_response": { + "summary": "Transcription response", + "value": { + "object": "audio.transcription", + "text": "Hello, this is a test of audio transcription.", + "language": "english", + "duration": 3.45, + "segments": [ + { + "id": 0, + "start": 0.0, + "end": 3.45, + "text": "Hello, this is a test of audio transcription.", + "temperature": 0.0, + "avg_logprob": -0.23, + "compression_ratio": 1.2, + "no_speech_prob": 0.01 + } + ], + "usage": { + "total_duration": 3.45 + }, + "extra_fields": { + "provider": "openai" + } + } + } + } + }, + "text/event-stream": { + "schema": { + "type": "string" + }, + "examples": { + "streaming_transcription": { + "summary": "Streaming transcription chunks", + "value": "data: {\"object\":\"audio.transcription.chunk\",\"text\":\"Hello\",\"type\":\"transcript.text.delta\"}\n\ndata: {\"object\":\"audio.transcription.chunk\",\"text\":\" world\",\"type\":\"transcript.text.delta\"}\n\ndata: [DONE]\n\n" + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/openai/v1/audio/speech": { + "post": { + "summary": "OpenAI Compatible - Create Speech", + "description": "OpenAI-compatible speech synthesis endpoint. Drop-in replacement for OpenAI's audio/speech API with identical request/response format.", + "operationId": "createOpenAISpeech", + "tags": ["OpenAI Integration", "Audio"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAISpeechRequest" + }, + "examples": { + "openai_speech": { + "summary": "OpenAI speech synthesis", + "value": { + "model": "tts-1", + "input": "The quick brown fox jumps over the lazy dog.", + "voice": "alloy", + "response_format": "mp3", + "speed": 1.0 + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Audio data", + "content": { + "audio/mpeg": { + "schema": { + "type": "string", + "format": "binary" + } + }, + "audio/wav": { + "schema": { + "type": "string", + "format": "binary" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + }, + "security": [ + { + "BearerAuth": [] + } + ] + } + }, + "/openai/audio/speech": { + "post": { + "summary": "OpenAI Compatible - Create Speech (Alternative)", + "description": "Alternative OpenAI-compatible speech synthesis endpoint without version prefix.", + "operationId": "createOpenAISpeechAlt", + "tags": ["OpenAI Integration", "Audio"], + "requestBody": { + "$ref": "#/paths/~1openai~1v1~1audio~1speech/post/requestBody" + }, + "responses": { + "200": { + "description": "Audio data", + "content": { + "audio/mpeg": { + "schema": { + "type": "string", + "format": "binary" + } + }, + "audio/wav": { + "schema": { + "type": "string", + "format": "binary" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + }, + "security": [ + { + "BearerAuth": [] + } + ] + } + }, + "/openai/v1/audio/transcriptions": { + "post": { + "summary": "OpenAI Compatible - Create Transcription", + "description": "OpenAI-compatible audio transcription endpoint. Drop-in replacement for OpenAI's audio/transcriptions API with identical request/response format.", + "operationId": "createOpenAITranscription", + "tags": ["OpenAI Integration", "Audio"], + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/OpenAITranscriptionRequest" + }, + "examples": { + "openai_transcription": { + "summary": "OpenAI transcription", + "value": { + "model": "whisper-1", + "file": "", + "language": "en", + "prompt": "Technical presentation about AI", + "response_format": "verbose_json", + "temperature": 0.0 + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Transcription result", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAITranscriptionResponse" + }, + "examples": { + "openai_transcription_response": { + "summary": "OpenAI transcription response", + "value": { + "text": "Hello, this is a test of audio transcription.", + "task": "transcribe", + "language": "english", + "duration": 3.45, + "segments": [ + { + "id": 0, + "seek": 0, + "start": 0.0, + "end": 3.45, + "text": "Hello, this is a test of audio transcription.", + "tokens": [15496, 11, 341, 307, 257, 1500, 295, 264, 6278, 35288, 4122, 13], + "temperature": 0.0, + "avg_logprob": -0.23, + "compression_ratio": 1.2, + "no_speech_prob": 0.01 + } + ], + "words": [ + {"word": "Hello", "start": 0.0, "end": 0.5}, + {"word": "this", "start": 0.6, "end": 0.8}, + {"word": "is", "start": 0.9, "end": 1.0} + ] + } + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + }, + "security": [ + { + "BearerAuth": [] + } + ] + } + }, + "/openai/audio/transcriptions": { + "post": { + "summary": "OpenAI Compatible - Create Transcription (Alternative)", + "description": "Alternative OpenAI-compatible transcription endpoint without version prefix.", + "operationId": "createOpenAITranscriptionAlt", + "tags": ["OpenAI Integration", "Audio"], + "requestBody": { + "$ref": "#/paths/~1openai~1v1~1audio~1transcriptions/post/requestBody" + }, + "responses": { + "200": { + "description": "Transcription result", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAITranscriptionResponse" + }, + "examples": { + "openai_transcription_response": { + "summary": "OpenAI transcription response", + "value": { + "text": "Hello, this is a test of audio transcription.", + "task": "transcribe", + "language": "english", + "duration": 3.45, + "segments": [ + { + "id": 0, + "seek": 0, + "start": 0.0, + "end": 3.45, + "text": "Hello, this is a test of audio transcription.", + "tokens": [15496, 11, 341, 307, 257, 1500, 295, 264, 6278, 35288, 4122, 13], + "temperature": 0.0, + "avg_logprob": -0.23, + "compression_ratio": 1.2, + "no_speech_prob": 0.01 + } + ], + "words": [ + {"word": "Hello", "start": 0.0, "end": 0.5}, + {"word": "this", "start": 0.6, "end": 0.8}, + {"word": "is", "start": 0.9, "end": 1.0} + ] + } + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + }, + "security": [ + { + "BearerAuth": [] + } + ] + } + }, + "/api/mcp/clients": { + "get": { + "summary": "List MCP Clients", + "description": "Get information about all configured MCP (Model Context Protocol) clients including their connection state, available tools, and configuration.", + "operationId": "listMCPClients", + "tags": ["MCP Management"], + "responses": { + "200": { + "description": "List of MCP clients", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MCPClient" + } + }, + "examples": { + "mcp_clients": { + "summary": "MCP clients list", + "value": [ + { + "name": "filesystem", + "config": { + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem"] + }, + "tools_to_execute": ["read_file", "list_directory"] + }, + "tools": ["read_file", "list_directory", "write_file"], + "state": "connected" + } + ] + } + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/mcp/client": { + "post": { + "summary": "Add MCP Client", + "description": "Add a new MCP client configuration and establish connection. Supports STDIO, HTTP, and SSE connection types.", + "operationId": "addMCPClient", + "tags": ["MCP Management"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPClientConfig" + }, + "examples": { + "stdio_client": { + "summary": "STDIO MCP Client", + "value": { + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem"], + "envs": ["HOME"] + }, + "tools_to_execute": ["read_file", "list_directory"] + } + }, + "http_client": { + "summary": "HTTP MCP Client", + "value": { + "name": "remote-api", + "connection_type": "http", + "connection_string": "https://api.example.com/mcp" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "MCP client added successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/mcp/client/{name}": { + "put": { + "summary": "Edit MCP Client Tools", + "description": "Modify which tools are available from an MCP client by updating tool filters.", + "operationId": "editMCPClientTools", + "tags": ["MCP Management"], + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string" + }, + "description": "Name of the MCP client to edit" + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPClientToolsEdit" + }, + "examples": { + "whitelist_tools": { + "summary": "Allow only specific tools", + "value": { + "tools_to_execute": ["read_file", "list_directory"] + } + }, + "blacklist_tools": { + "summary": "Block dangerous tools", + "value": { + "tools_to_skip": ["delete_file", "write_file"] + } + } + } + } + } + }, + "responses": { + "200": { + "description": "MCP client tools updated successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + }, + "delete": { + "summary": "Remove MCP Client", + "description": "Remove an MCP client configuration and disconnect it from the system.", + "operationId": "removeMCPClient", + "tags": ["MCP Management"], + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string" + }, + "description": "Name of the MCP client to remove" + } + ], + "responses": { + "200": { + "description": "MCP client removed successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/mcp/client/{name}/reconnect": { + "post": { + "summary": "Reconnect MCP Client", + "description": "Reconnect a disconnected or errored MCP client.", + "operationId": "reconnectMCPClient", + "tags": ["MCP Management"], + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string" + }, + "description": "Name of the MCP client to reconnect" + } + ], + "responses": { + "200": { + "description": "MCP client reconnected successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/providers": { + "get": { + "summary": "List Providers", + "description": "Get a list of all configured AI providers with their settings and capabilities.", + "operationId": "listProviders", + "tags": ["Provider Management"], + "responses": { + "200": { + "description": "List of providers", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListProvidersResponse" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + }, + "post": { + "summary": "Add Provider", + "description": "Add a new AI provider configuration with API keys, network settings, and concurrency options.", + "operationId": "addProvider", + "tags": ["Provider Management"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddProviderRequest" + }, + "examples": { + "openai_provider": { + "summary": "Add OpenAI Provider", + "value": { + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "weight": 1.0, + "models": ["gpt-4o", "gpt-4o-mini"] + } + ], + "concurrency_and_buffer_size": { + "concurrency": 10, + "buffer_size": 100 + } + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Provider added successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/providers/{provider}": { + "get": { + "summary": "Get Provider", + "description": "Get detailed configuration for a specific AI provider.", + "operationId": "getProvider", + "tags": ["Provider Management"], + "parameters": [ + { + "name": "provider", + "in": "path", + "required": true, + "schema": { + "$ref": "#/components/schemas/ModelProvider" + }, + "description": "Provider name" + } + ], + "responses": { + "200": { + "description": "Provider configuration", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderResponse" + } + } + } + }, + "404": { + "description": "Provider not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + }, + "put": { + "summary": "Update Provider", + "description": "Update an existing AI provider's configuration including API keys, network settings, and concurrency options.", + "operationId": "updateProvider", + "tags": ["Provider Management"], + "parameters": [ + { + "name": "provider", + "in": "path", + "required": true, + "schema": { + "$ref": "#/components/schemas/ModelProvider" + }, + "description": "Provider name" + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateProviderRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Provider updated successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "404": { + "description": "Provider not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + }, + "delete": { + "summary": "Delete Provider", + "description": "Remove an AI provider configuration from the system.", + "operationId": "deleteProvider", + "tags": ["Provider Management"], + "parameters": [ + { + "name": "provider", + "in": "path", + "required": true, + "schema": { + "$ref": "#/components/schemas/ModelProvider" + }, + "description": "Provider name" + } + ], + "responses": { + "200": { + "description": "Provider deleted successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "404": { + "description": "Provider not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/config": { + "get": { + "summary": "Get Configuration", + "description": "Get the current system configuration including logging, pool size, and other runtime settings.", + "operationId": "getConfig", + "tags": ["Configuration"], + "responses": { + "200": { + "description": "Current configuration", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClientConfig" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + }, + "put": { + "summary": "Update Configuration", + "description": "Update system configuration settings. Supports hot-reloading of certain settings like drop_excess_requests.", + "operationId": "updateConfig", + "tags": ["Configuration"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClientConfig" + } + } + } + }, + "responses": { + "200": { + "description": "Configuration updated successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SuccessResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/logs": { + "get": { + "summary": "Get Application Logs", + "description": "Retrieve application logs with filtering, search, and pagination. Supports filtering by provider, model, status, time range, latency, tokens, and content search.", + "operationId": "getLogs", + "tags": ["Logging"], + "parameters": [ + { + "name": "providers", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Comma-separated list of providers to filter by" + }, + { + "name": "models", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Comma-separated list of models to filter by" + }, + { + "name": "status", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Comma-separated list of statuses to filter by (success, error)" + }, + { + "name": "objects", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Comma-separated list of object types to filter by (chat.completion, text.completion)" + }, + { + "name": "start_time", + "in": "query", + "schema": { + "type": "string", + "format": "date-time" + }, + "description": "Start time for filtering (RFC3339 format)" + }, + { + "name": "end_time", + "in": "query", + "schema": { + "type": "string", + "format": "date-time" + }, + "description": "End time for filtering (RFC3339 format)" + }, + { + "name": "min_latency", + "in": "query", + "schema": { + "type": "number" + }, + "description": "Minimum latency in seconds" + }, + { + "name": "max_latency", + "in": "query", + "schema": { + "type": "number" + }, + "description": "Maximum latency in seconds" + }, + { + "name": "min_tokens", + "in": "query", + "schema": { + "type": "integer" + }, + "description": "Minimum token count" + }, + { + "name": "max_tokens", + "in": "query", + "schema": { + "type": "integer" + }, + "description": "Maximum token count" + }, + { + "name": "content_search", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Search term for message content" + }, + { + "name": "limit", + "in": "query", + "schema": { + "type": "integer", + "minimum": 1, + "maximum": 1000, + "default": 50 + }, + "description": "Number of logs to return (1-1000)" + }, + { + "name": "offset", + "in": "query", + "schema": { + "type": "integer", + "minimum": 0, + "default": 0 + }, + "description": "Number of logs to skip" + } + ], + "responses": { + "200": { + "description": "Application logs", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LogSearchResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/api/logs/dropped": { + "get": { + "summary": "Get Dropped Requests", + "description": "Get information about requests that were dropped due to queue overflow or other capacity limits.", + "operationId": "getDroppedRequests", + "tags": ["Logging"], + "responses": { + "200": { + "description": "Dropped requests information", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DroppedRequestsResponse" + } + } + } + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/ws/logs": { + "get": { + "summary": "WebSocket Log Stream", + "description": "Establish a WebSocket connection for real-time log streaming. Supports filtering similar to the /api/logs endpoint via query parameters.", + "operationId": "logWebSocket", + "tags": ["WebSocket"], + "parameters": [ + { + "name": "Upgrade", + "in": "header", + "required": true, + "schema": { + "type": "string", + "enum": ["websocket"] + } + }, + { + "name": "Connection", + "in": "header", + "required": true, + "schema": { + "type": "string", + "enum": ["Upgrade"] + } + } + ], + "responses": { + "101": { + "description": "WebSocket connection established", + "headers": { + "Upgrade": { + "schema": { + "type": "string", + "enum": ["websocket"] + } + }, + "Connection": { + "schema": { + "type": "string", + "enum": ["Upgrade"] + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/openai/chat/completions": { + "post": { + "summary": "OpenAI Compatible Chat Completions", + "description": "OpenAI-compatible chat completions endpoint that converts requests to Bifrost format and returns OpenAI-compatible responses.", + "operationId": "openaiChatCompletions", + "tags": ["Integration - OpenAI"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + } + } + }, + "responses": { + "200": { + "description": "OpenAI-compatible chat completion response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/anthropic/v1/messages": { + "post": { + "summary": "Anthropic Compatible Messages", + "description": "Anthropic-compatible messages endpoint that converts requests to Bifrost format and returns Anthropic-compatible responses.", + "operationId": "anthropicMessages", + "tags": ["Integration - Anthropic"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Anthropic-compatible message response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/genai/v1beta/models/{model}": { + "post": { + "summary": "Google Gemini Compatible Completions", + "description": "Google Gemini-compatible completions endpoint that converts requests to Bifrost format and returns Gemini-compatible responses.", + "operationId": "geminiCompletions", + "tags": ["Integration - Gemini"], + "parameters": [ + { + "name": "model", + "in": "path", + "required": true, + "schema": { + "type": "string" + }, + "description": "Model name" + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Gemini-compatible completion response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/litellm/chat/completions": { + "post": { + "summary": "LiteLLM Compatible Chat Completions", + "description": "LiteLLM-compatible chat completions endpoint that automatically detects the provider from the model name and converts requests accordingly.", + "operationId": "litellmChatCompletions", + "tags": ["Integration - LiteLLM"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + } + } + }, + "responses": { + "200": { + "description": "LiteLLM-compatible chat completion response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/metrics": { + "get": { + "summary": "Get Prometheus Metrics", + "description": "Returns Prometheus-compatible metrics for monitoring request counts, latency, token usage, and error rates.", + "operationId": "getMetrics", + "tags": ["Monitoring"], + "responses": { + "200": { + "description": "Prometheus metrics in text format", + "content": { + "text/plain": { + "schema": { + "type": "string" + }, + "example": "# HELP http_requests_total Total number of HTTP requests\n# TYPE http_requests_total counter\nhttp_requests_total{method=\"POST\",handler=\"/v1/chat/completions\",code=\"200\"} 42\n" + } + } + } + } + } + } + }, + "components": { + "schemas": { + "ChatCompletionRequest": { + "type": "object", + "required": ["model", "messages"], + "properties": { + "model": { + "type": "string", + "description": "Model identifier in 'provider/model' format (e.g., 'openai/gpt-4o-mini', 'anthropic/claude-3-sonnet-20240229')", + "example": "openai/gpt-4o-mini" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "description": "Array of chat messages", + "minItems": 1 + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate", + "example": 1000 + }, + "fallbacks": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Fallback model names in 'provider/model' format", + "example": ["anthropic/claude-3-sonnet-20240229", "openai/gpt-4o"] + } + } + }, + "TextCompletionRequest": { + "type": "object", + "required": ["model", "text"], + "properties": { + "model": { + "type": "string", + "description": "Model identifier in 'provider/model' format (e.g., 'anthropic/claude-2.1')", + "example": "anthropic/claude-2.1" + }, + "text": { + "type": "string", + "description": "Text prompt for completion", + "example": "The benefits of artificial intelligence include" + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate", + "example": 1000 + }, + "temperature": { + "type": "number", + "minimum": 0.0, + "maximum": 2.0, + "description": "Controls randomness in the output", + "example": 0.7 + }, + "stop_sequences": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Sequences that stop generation" + }, + "fallbacks": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Fallback model names in 'provider/model' format", + "example": [ + "anthropic/claude-3-haiku-20240307", + "openai/gpt-4o-mini" + ] + } + } + }, + "ModelProvider": { + "type": "string", + "enum": [ + "openai", + "anthropic", + "azure", + "bedrock", + "cohere", + "vertex", + "mistral", + "ollama", + "gemini", + "groq", + "openrouter", + "sgl", + "parasail", + "cerebras" + ], + "description": "AI model provider", + "example": "openai" + }, + "BifrostMessage": { + "type": "object", + "required": ["role"], + "properties": { + "role": { + "$ref": "#/components/schemas/MessageRole" + }, + "content": { + "oneOf": [ + { + "type": "string", + "description": "Simple text content", + "example": "Hello, how are you?" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/ContentBlock" + }, + "description": "Structured content with text and images" + } + ], + "description": "Message content - can be simple text or structured content with text and images" + }, + "tool_call_id": { + "type": "string", + "description": "ID of the tool call (for tool messages)" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "description": "Tool calls made by assistant" + }, + "refusal": { + "type": "string", + "description": "Refusal message from assistant" + }, + "annotations": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Annotation" + }, + "description": "Message annotations" + }, + "thought": { + "type": "string", + "description": "Assistant's internal thought process" + } + } + }, + "MessageRole": { + "type": "string", + "enum": ["user", "assistant", "system", "tool"], + "description": "Role of the message sender", + "example": "user" + }, + "ContentBlock": { + "type": "object", + "required": ["type"], + "discriminator": { + "propertyName": "type" + }, + "oneOf": [ + { + "type": "object", + "required": ["type", "text"], + "properties": { + "type": { + "type": "string", + "enum": ["text"], + "description": "Content type for text blocks", + "example": "text" + }, + "text": { + "type": "string", + "description": "Text content", + "example": "What do you see in this image?" + } + }, + "additionalProperties": false + }, + { + "type": "object", + "required": ["type", "image_url"], + "properties": { + "type": { + "type": "string", + "enum": ["image_url"], + "description": "Content type for image blocks", + "example": "image_url" + }, + "image_url": { + "$ref": "#/components/schemas/ImageURLStruct", + "description": "Image data" + } + }, + "additionalProperties": false + }, + { + "type": "object", + "required": ["type", "input_audio"], + "properties": { + "type": { + "type": "string", + "enum": ["input_audio"], + "description": "Content type for audio blocks", + "example": "input_audio" + }, + "input_audio": { + "$ref": "#/components/schemas/InputAudioStruct", + "description": "Audio data" + } + }, + "additionalProperties": false + } + ] + }, + "ImageURLStruct": { + "type": "object", + "required": ["url"], + "properties": { + "url": { + "type": "string", + "description": "Image URL or data URI", + "example": "https://example.com/image.jpg" + }, + "detail": { + "type": "string", + "enum": ["low", "high", "auto"], + "description": "Image detail level", + "example": "auto" + } + } + }, + "InputAudioStruct": { + "type": "object", + "required": ["data"], + "properties": { + "data": { + "type": "string", + "description": "Audio payload (opaque string such as a data URL or provider-accepted encoded content)" + }, + "format": { + "type": "string", + "description": "Optional audio format (e.g., \"mp3\", \"wav\") or MIME type (e.g., \"audio/mp3\"); providers may auto-detect when omitted" + } + } + }, + "ModelParameters": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "minimum": 0.0, + "maximum": 2.0, + "description": "Controls randomness in the output", + "example": 0.7 + }, + "top_p": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + "description": "Nucleus sampling parameter", + "example": 0.9 + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter", + "example": 40 + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate", + "example": 1000 + }, + "stop_sequences": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Sequences that stop generation", + "example": ["\n\n", "END"] + }, + "presence_penalty": { + "type": "number", + "minimum": -2.0, + "maximum": 2.0, + "description": "Penalizes repeated tokens", + "example": 0.0 + }, + "frequency_penalty": { + "type": "number", + "minimum": -2.0, + "maximum": 2.0, + "description": "Penalizes frequent tokens", + "example": 0.0 + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Tool" + }, + "description": "Available tools for the model" + }, + "tool_choice": { + "$ref": "#/components/schemas/ToolChoice" + }, + "parallel_tool_calls": { + "type": "boolean", + "description": "Enable parallel tool execution", + "example": true + } + } + }, + "Tool": { + "type": "object", + "required": ["type", "function"], + "properties": { + "id": { + "type": "string", + "description": "Unique tool identifier" + }, + "type": { + "type": "string", + "enum": ["function"], + "description": "Tool type", + "example": "function" + }, + "function": { + "$ref": "#/components/schemas/Function" + } + } + }, + "Function": { + "type": "object", + "required": ["name", "description", "parameters"], + "properties": { + "name": { + "type": "string", + "description": "Function name", + "example": "get_weather" + }, + "description": { + "type": "string", + "description": "Function description", + "example": "Get current weather for a location" + }, + "parameters": { + "$ref": "#/components/schemas/FunctionParameters" + } + } + }, + "FunctionParameters": { + "type": "object", + "required": ["type"], + "properties": { + "type": { + "type": "string", + "description": "Parameter type", + "example": "object" + }, + "description": { + "type": "string", + "description": "Parameter description" + }, + "properties": { + "type": "object", + "additionalProperties": true, + "description": "Parameter properties (JSON Schema)" + }, + "required": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Required parameter names" + }, + "enum": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Enum values for parameters" + } + } + }, + "ToolChoice": { + "type": "object", + "required": ["type"], + "properties": { + "type": { + "type": "string", + "enum": ["none", "auto", "any", "function", "required"], + "description": "How tools should be chosen", + "example": "auto" + }, + "function": { + "$ref": "#/components/schemas/ToolChoiceFunction" + } + } + }, + "ToolChoiceFunction": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string", + "description": "Name of the function to call", + "example": "get_weather" + } + } + }, + "ToolCall": { + "type": "object", + "required": ["function"], + "properties": { + "id": { + "type": "string", + "description": "Unique tool call identifier", + "example": "tool_123" + }, + "type": { + "type": "string", + "enum": ["function"], + "description": "Tool call type", + "example": "function" + }, + "function": { + "$ref": "#/components/schemas/FunctionCall" + } + } + }, + "FunctionCall": { + "type": "object", + "required": ["name", "arguments"], + "properties": { + "name": { + "type": "string", + "description": "Function name", + "example": "get_weather" + }, + "arguments": { + "type": "string", + "description": "JSON string of function arguments", + "example": "{\"location\": \"San Francisco, CA\"}" + } + } + }, + "Annotation": { + "type": "object", + "required": ["type", "url_citation"], + "properties": { + "type": { + "type": "string", + "description": "Annotation type" + }, + "url_citation": { + "$ref": "#/components/schemas/Citation" + } + } + }, + "Citation": { + "type": "object", + "required": ["start_index", "end_index", "title"], + "properties": { + "start_index": { + "type": "integer", + "description": "Start index in the text" + }, + "end_index": { + "type": "integer", + "description": "End index in the text" + }, + "title": { + "type": "string", + "description": "Citation title" + }, + "url": { + "type": "string", + "description": "Citation URL" + }, + "sources": { + "description": "Citation sources" + }, + "type": { + "type": "string", + "description": "Citation type" + } + } + }, + "BifrostResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique response identifier", + "example": "chatcmpl-123" + }, + "object": { + "type": "string", + "enum": ["chat.completion", "text.completion"], + "description": "Response type", + "example": "chat.completion" + }, + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostResponseChoice" + }, + "description": "Array of completion choices" + }, + "model": { + "type": "string", + "description": "Model used for generation", + "example": "gpt-4o" + }, + "created": { + "type": "integer", + "description": "Unix timestamp of creation", + "example": 1677652288 + }, + "service_tier": { + "type": "string", + "description": "Service tier used" + }, + "system_fingerprint": { + "type": "string", + "description": "System fingerprint" + }, + "usage": { + "$ref": "#/components/schemas/LLMUsage" + }, + "extra_fields": { + "$ref": "#/components/schemas/BifrostResponseExtraFields" + } + } + }, + "BifrostResponseChoice": { + "type": "object", + "required": ["index", "message"], + "properties": { + "index": { + "type": "integer", + "description": "Choice index", + "example": 0 + }, + "message": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "finish_reason": { + "type": "string", + "enum": [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call" + ], + "description": "Reason completion stopped", + "example": "stop" + }, + "stop": { + "type": "string", + "description": "Stop sequence that ended generation" + }, + "log_probs": { + "$ref": "#/components/schemas/LogProbs" + } + } + }, + "LLMUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "integer", + "description": "Tokens in the prompt", + "example": 56 + }, + "completion_tokens": { + "type": "integer", + "description": "Tokens in the completion", + "example": 31 + }, + "total_tokens": { + "type": "integer", + "description": "Total tokens used", + "example": 87 + }, + "completion_tokens_details": { + "$ref": "#/components/schemas/CompletionTokensDetails" + } + } + }, + "CompletionTokensDetails": { + "type": "object", + "properties": { + "reasoning_tokens": { + "type": "integer", + "description": "Tokens used for reasoning" + }, + "audio_tokens": { + "type": "integer", + "description": "Tokens used for audio" + }, + "accepted_prediction_tokens": { + "type": "integer", + "description": "Accepted prediction tokens" + }, + "rejected_prediction_tokens": { + "type": "integer", + "description": "Rejected prediction tokens" + } + } + }, + "BifrostResponseExtraFields": { + "type": "object", + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model_params": { + "$ref": "#/components/schemas/ModelParameters" + }, + "latency": { + "type": "number", + "description": "Request latency in seconds", + "example": 1.234 + }, + "chat_history": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "description": "Full conversation history" + }, + "billed_usage": { + "$ref": "#/components/schemas/BilledLLMUsage" + }, + "raw_response": { + "type": "object", + "description": "Raw provider response" + } + } + }, + "BilledLLMUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "number", + "description": "Billed prompt tokens" + }, + "completion_tokens": { + "type": "number", + "description": "Billed completion tokens" + }, + "search_units": { + "type": "number", + "description": "Billed search units" + }, + "classifications": { + "type": "number", + "description": "Billed classifications" + } + } + }, + "LogProbs": { + "type": "object", + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ContentLogProb" + }, + "description": "Log probabilities for content" + }, + "refusal": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LogProb" + }, + "description": "Log probabilities for refusal" + } + } + }, + "ContentLogProb": { + "type": "object", + "required": ["logprob", "token"], + "properties": { + "bytes": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "Byte representation" + }, + "logprob": { + "type": "number", + "description": "Log probability", + "example": -0.123 + }, + "token": { + "type": "string", + "description": "Token", + "example": "hello" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LogProb" + }, + "description": "Top log probabilities" + } + } + }, + "LogProb": { + "type": "object", + "required": ["logprob", "token"], + "properties": { + "bytes": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "Byte representation" + }, + "logprob": { + "type": "number", + "description": "Log probability", + "example": -0.456 + }, + "token": { + "type": "string", + "description": "Token", + "example": "world" + } + } + }, + "BifrostError": { + "type": "object", + "required": ["is_bifrost_error", "error"], + "properties": { + "event_id": { + "type": "string", + "description": "Unique error event ID", + "example": "evt_123" + }, + "type": { + "type": "string", + "description": "Error type", + "example": "invalid_request_error" + }, + "is_bifrost_error": { + "type": "boolean", + "description": "Whether error originated from Bifrost", + "example": true + }, + "status_code": { + "type": "integer", + "description": "HTTP status code", + "example": 400 + }, + "error": { + "$ref": "#/components/schemas/ErrorField" + } + } + }, + "ErrorField": { + "type": "object", + "required": ["message"], + "properties": { + "type": { + "type": "string", + "description": "Error type", + "example": "invalid_request_error" + }, + "code": { + "type": "string", + "description": "Error code", + "example": "missing_required_parameter" + }, + "message": { + "type": "string", + "description": "Human-readable error message", + "example": "Provider is required" + }, + "param": { + "description": "Parameter that caused the error", + "example": "provider" + }, + "event_id": { + "type": "string", + "description": "Error event ID", + "example": "evt_123" + } + } + }, + "MCPClient": { + "type": "object", + "required": ["name", "config", "tools", "state"], + "properties": { + "name": { + "type": "string", + "description": "Unique name for this MCP client", + "example": "filesystem" + }, + "config": { + "$ref": "#/components/schemas/MCPClientConfig" + }, + "tools": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Available tools from this client", + "example": ["read_file", "list_directory", "write_file"] + }, + "state": { + "$ref": "#/components/schemas/MCPConnectionState" + } + } + }, + "MCPClientConfig": { + "type": "object", + "required": ["name", "connection_type"], + "properties": { + "name": { + "type": "string", + "description": "Client name", + "example": "filesystem" + }, + "connection_type": { + "$ref": "#/components/schemas/MCPConnectionType" + }, + "connection_string": { + "type": "string", + "description": "HTTP or SSE URL (required for HTTP or SSE connections)", + "example": "https://api.example.com/mcp" + }, + "stdio_config": { + "$ref": "#/components/schemas/MCPStdioConfig" + }, + "tools_to_skip": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tools to exclude from this client", + "example": ["delete_file", "write_file"] + }, + "tools_to_execute": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tools to include from this client (if specified, only these are used)", + "example": ["read_file", "list_directory"] + } + } + }, + "MCPConnectionType": { + "type": "string", + "enum": ["http", "stdio", "sse"], + "description": "Communication protocol for MCP connections", + "example": "stdio" + }, + "MCPStdioConfig": { + "type": "object", + "required": ["command", "args"], + "properties": { + "command": { + "type": "string", + "description": "Executable command to run", + "example": "npx" + }, + "args": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Command line arguments", + "example": ["-y", "@modelcontextprotocol/server-filesystem"] + }, + "envs": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Environment variables required", + "example": ["HOME", "USER"] + } + } + }, + "MCPConnectionState": { + "type": "string", + "enum": ["connected", "disconnected", "error"], + "description": "Connection state of MCP client", + "example": "connected" + }, + "MCPClientToolsEdit": { + "type": "object", + "properties": { + "tools_to_execute": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tools to allow from this client (whitelist)", + "example": ["read_file", "list_directory"] + }, + "tools_to_skip": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tools to block from this client (blacklist)", + "example": ["delete_file", "write_file"] + } + } + }, + "SuccessResponse": { + "type": "object", + "required": ["status", "message"], + "properties": { + "status": { + "type": "string", + "enum": ["success"], + "description": "Operation status", + "example": "success" + }, + "message": { + "type": "string", + "description": "Success message", + "example": "Operation completed successfully" + } + } + }, + "AddProviderRequest": { + "type": "object", + "required": ["provider", "keys"], + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "keys": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Key" + }, + "description": "API keys for the provider" + }, + "network_config": { + "$ref": "#/components/schemas/NetworkConfig" + }, + "concurrency_and_buffer_size": { + "$ref": "#/components/schemas/ConcurrencyAndBufferSize" + }, + "proxy_config": { + "$ref": "#/components/schemas/ProxyConfig" + } + } + }, + "UpdateProviderRequest": { + "type": "object", + "required": ["keys", "network_config", "concurrency_and_buffer_size"], + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Key" + }, + "description": "API keys for the provider" + }, + "network_config": { + "$ref": "#/components/schemas/NetworkConfig" + }, + "concurrency_and_buffer_size": { + "$ref": "#/components/schemas/ConcurrencyAndBufferSize" + }, + "proxy_config": { + "$ref": "#/components/schemas/ProxyConfig" + } + } + }, + "ProviderResponse": { + "type": "object", + "required": ["name", "keys", "network_config", "concurrency_and_buffer_size"], + "properties": { + "name": { + "$ref": "#/components/schemas/ModelProvider" + }, + "keys": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Key" + }, + "description": "API keys for the provider" + }, + "network_config": { + "$ref": "#/components/schemas/NetworkConfig" + }, + "concurrency_and_buffer_size": { + "$ref": "#/components/schemas/ConcurrencyAndBufferSize" + }, + "proxy_config": { + "$ref": "#/components/schemas/ProxyConfig" + } + } + }, + "ListProvidersResponse": { + "type": "object", + "required": ["providers", "total"], + "properties": { + "providers": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ProviderResponse" + }, + "description": "List of configured providers" + }, + "total": { + "type": "integer", + "description": "Total number of providers", + "example": 3 + } + } + }, + "Key": { + "type": "object", + "required": ["value"], + "properties": { + "value": { + "type": "string", + "description": "API key value or environment variable reference", + "example": "env.OPENAI_API_KEY" + }, + "weight": { + "type": "number", + "description": "Weight for load balancing", + "example": 1.0 + }, + "models": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Models this key can access", + "example": ["gpt-4o", "gpt-4o-mini"] + }, + "azure_key_config": { + "type": "object", + "properties": { + "endpoint": { + "type": "string", + "description": "Azure endpoint", + "example": "https://your-resource.openai.azure.com" + }, + "deployments": { + "type": "object", + "description": "Azure deployments", + "example": { + "gpt-4o": "gpt-4o-deployment" + } + }, + "api_version": { + "type": "string", + "description": "Azure API version", + "example": "2024-02-15-preview" + } + }, + "description": "Azure key configuration" + }, + "vertex_key_config": { + "type": "object", + "properties": { + "project_id": { + "type": "string", + "description": "Vertex project ID", + "example": "your-project-id" + }, + "region": { + "type": "string", + "description": "Vertex region", + "example": "us-central1" + }, + "auth_credentials": { + "type": "string", + "description": "Vertex auth credentials", + "example": "env.VERTEX_AUTH_CREDENTIALS" + } + }, + "description": "Vertex key configuration" + }, + "bedrock_key_config": { + "type": "object", + "properties": { + "access_key": { + "type": "string", + "description": "Bedrock access key", + "example": "env.AWS_ACCESS_KEY_ID" + }, + "secret_key": { + "type": "string", + "description": "Bedrock secret key", + "example": "env.AWS_SECRET_ACCESS_KEY" + }, + "session_token": { + "type": "string", + "description": "Bedrock session token", + "example": "env.AWS_SESSION_TOKEN" + }, + "region": { + "type": "string", + "description": "Bedrock region", + "example": "us-east-1" + }, + "arn": { + "type": "string", + "description": "Bedrock ARN", + "example": "arn:aws:iam::123456789012:role/BedrockRole" + }, + "deployments": { + "type": "object", + "description": "Bedrock deployments", + "example": { + "gpt-4o": "gpt-4o-deployment" + } + } + } + } + } + }, + "NetworkConfig": { + "type": "object", + "properties": { + "timeout": { + "type": "integer", + "description": "Request timeout in seconds", + "example": 30 + }, + "max_retries": { + "type": "integer", + "description": "Maximum number of retries", + "example": 3 + } + } + }, + "ConcurrencyAndBufferSize": { + "type": "object", + "properties": { + "concurrency": { + "type": "integer", + "description": "Maximum concurrent requests", + "example": 10 + }, + "buffer_size": { + "type": "integer", + "description": "Request buffer size", + "example": 100 + } + } + }, + "ProxyConfig": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "Proxy URL", + "example": "http://proxy.example.com:8080" + }, + "username": { + "type": "string", + "description": "Proxy username" + }, + "password": { + "type": "string", + "description": "Proxy password" + } + } + }, + "ClientConfig": { + "type": "object", + "properties": { + "initial_pool_size": { + "type": "integer", + "description": "Initial pool size for sync pools", + "example": 100 + }, + "drop_excess_requests": { + "type": "boolean", + "description": "Whether to drop requests when queue is full", + "example": false + }, + "enable_logging": { + "type": "boolean", + "description": "Whether logging is enabled", + "example": true + }, + "prometheus_labels": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Prometheus metric labels", + "example": ["environment", "service"] + } + } + }, + "LogSearchResponse": { + "type": "object", + "required": ["logs", "total", "limit", "offset"], + "properties": { + "logs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LogEntry" + }, + "description": "Array of log entries" + }, + "total": { + "type": "integer", + "description": "Total number of matching logs", + "example": 156 + }, + "limit": { + "type": "integer", + "description": "Number of logs per page", + "example": 50 + }, + "offset": { + "type": "integer", + "description": "Number of logs skipped", + "example": 0 + } + } + }, + "LogEntry": { + "type": "object", + "required": ["id", "timestamp", "level", "message"], + "properties": { + "id": { + "type": "string", + "description": "Unique log entry ID", + "example": "log_123" + }, + "timestamp": { + "type": "string", + "format": "date-time", + "description": "Log entry timestamp", + "example": "2023-12-01T10:30:00Z" + }, + "level": { + "type": "string", + "enum": ["debug", "info", "warn", "error"], + "description": "Log level", + "example": "info" + }, + "message": { + "type": "string", + "description": "Log message", + "example": "Request completed successfully" + }, + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model": { + "type": "string", + "description": "Model name used", + "example": "gpt-4o" + }, + "status": { + "type": "string", + "enum": ["success", "error"], + "description": "Request status", + "example": "success" + }, + "latency": { + "type": "number", + "description": "Request latency in seconds", + "example": 1.234 + }, + "tokens": { + "type": "integer", + "description": "Total tokens used", + "example": 87 + }, + "object": { + "type": "string", + "enum": ["chat.completion", "text.completion"], + "description": "Request object type", + "example": "chat.completion" + } + } + }, + "DroppedRequestsResponse": { + "type": "object", + "required": ["total_dropped", "recent_drops"], + "properties": { + "total_dropped": { + "type": "integer", + "description": "Total number of dropped requests", + "example": 5 + }, + "recent_drops": { + "type": "array", + "items": { + "$ref": "#/components/schemas/DroppedRequest" + }, + "description": "Recent dropped requests" + } + } + }, + "DroppedRequest": { + "type": "object", + "required": ["timestamp", "reason"], + "properties": { + "timestamp": { + "type": "string", + "format": "date-time", + "description": "When the request was dropped", + "example": "2023-12-01T10:30:00Z" + }, + "reason": { + "type": "string", + "description": "Reason for dropping the request", + "example": "Queue overflow" + }, + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model": { + "type": "string", + "description": "Model name requested", + "example": "gpt-4o" + } + } + }, + "SpeechRequest": { + "type": "object", + "required": ["model", "input", "voice"], + "properties": { + "model": { + "type": "string", + "description": "Model to use for speech synthesis in 'provider/model' format", + "example": "openai/tts-1" + }, + "input": { + "type": "string", + "description": "Text to convert to speech (max 4096 characters)", + "example": "Hello! This is a test of speech synthesis.", + "maxLength": 4096 + }, + "voice": { + "type": "string", + "description": "Voice to use for speech synthesis", + "enum": ["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + "example": "alloy" + }, + "response_format": { + "type": "string", + "description": "Audio format for the response", + "enum": ["mp3", "opus", "aac", "flac", "wav", "pcm"], + "default": "mp3", + "example": "mp3" + }, + "instructions": { + "type": "string", + "description": "Additional instructions for voice synthesis", + "example": "Speak slowly and clearly" + }, + "stream_format": { + "type": "string", + "description": "Enable streaming with Server-Sent Events", + "enum": ["sse"], + "example": "sse" + } + } + }, + "TranscriptionRequest": { + "type": "object", + "required": ["model", "file"], + "properties": { + "model": { + "type": "string", + "description": "Model to use for transcription in 'provider/model' format", + "example": "openai/whisper-1" + }, + "file": { + "type": "string", + "format": "binary", + "description": "Audio file to transcribe (mp3, mp4, mpeg, mpga, m4a, wav, webm, max 25MB)" + }, + "language": { + "type": "string", + "description": "Language of the input audio (ISO-639-1 format)", + "example": "en" + }, + "prompt": { + "type": "string", + "description": "Optional text to guide the model's style or continue a previous audio segment", + "example": "This is a recording of a technical presentation about AI." + }, + "response_format": { + "type": "string", + "description": "Format of the transcript output", + "enum": ["json", "text", "srt", "verbose_json", "vtt"], + "default": "json", + "example": "verbose_json" + }, + "temperature": { + "type": "number", + "description": "Sampling temperature (0 to 1)", + "minimum": 0, + "maximum": 1, + "example": 0.0 + }, + "stream": { + "type": "string", + "description": "Enable streaming transcription", + "enum": ["true", "false"], + "example": "true" + } + } + }, + "OpenAISpeechRequest": { + "type": "object", + "required": ["model", "input", "voice"], + "properties": { + "model": { + "type": "string", + "description": "TTS model to use", + "enum": ["tts-1", "tts-1-hd"], + "example": "tts-1" + }, + "input": { + "type": "string", + "description": "Text to generate audio for (max 4096 characters)", + "example": "The quick brown fox jumped over the lazy dog.", + "maxLength": 4096 + }, + "voice": { + "type": "string", + "description": "Voice to use when generating the audio", + "enum": ["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + "example": "alloy" + }, + "response_format": { + "type": "string", + "description": "Format to audio in", + "enum": ["mp3", "opus", "aac", "flac", "wav", "pcm"], + "default": "mp3", + "example": "mp3" + }, + "speed": { + "type": "number", + "description": "Speed of the generated audio (0.25 to 4.0)", + "minimum": 0.25, + "maximum": 4.0, + "default": 1.0, + "example": 1.0 + }, + "instructions": { + "type": "string", + "description": "Additional instructions for voice synthesis", + "example": "Speak slowly and clearly" + }, + "stream_format": { + "type": "string", + "description": "Enable streaming with Server-Sent Events", + "enum": ["sse"], + "example": "sse" + } + } + }, + "OpenAITranscriptionRequest": { + "type": "object", + "required": ["model", "file"], + "properties": { + "model": { + "type": "string", + "description": "ID of the model to use", + "enum": ["whisper-1"], + "example": "whisper-1" + }, + "file": { + "type": "string", + "format": "binary", + "description": "Audio file object to transcribe (mp3, mp4, mpeg, mpga, m4a, wav, webm, max 25MB)" + }, + "language": { + "type": "string", + "description": "Language of the input audio (ISO-639-1 format)", + "example": "en" + }, + "prompt": { + "type": "string", + "description": "Optional text to guide the model's style or continue a previous audio segment", + "example": "ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Pro, CyberLeap" + }, + "response_format": { + "type": "string", + "description": "Format of the transcript output", + "enum": ["json", "text", "srt", "verbose_json", "vtt"], + "default": "json", + "example": "json" + }, + "temperature": { + "type": "number", + "description": "Sampling temperature (0 to 1)", + "minimum": 0, + "maximum": 1, + "default": 0, + "example": 0 + }, + "include": { + "type": "array", + "items": { + "type": "string", + "enum": ["segments"] + }, + "description": "Additional data to include in the response" + }, + "timestamp_granularities": { + "type": "array", + "items": { + "type": "string", + "enum": ["word", "segment"] + }, + "description": "Timestamp granularities to populate for this transcription", + "default": ["segment"] + }, + "stream": { + "type": "boolean", + "description": "Enable streaming transcription", + "default": false, + "example": false + } + } + }, + "OpenAITranscriptionResponse": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "description": "The transcribed text", + "example": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger." + }, + "task": { + "type": "string", + "description": "Task that was performed", + "example": "transcribe" + }, + "language": { + "type": "string", + "description": "Detected language of the input audio", + "example": "english" + }, + "duration": { + "type": "number", + "description": "Duration of the input audio in seconds", + "example": 8.470000267028809 + }, + "segments": { + "type": "array", + "description": "Segments of the transcribed text and their corresponding details", + "items": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "Unique identifier of the segment" + }, + "seek": { + "type": "integer", + "description": "Seek offset of the segment" + }, + "start": { + "type": "number", + "description": "Start time of the segment in seconds" + }, + "end": { + "type": "number", + "description": "End time of the segment in seconds" + }, + "text": { + "type": "string", + "description": "Text content of the segment" + }, + "tokens": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "Array of token IDs for the text content" + }, + "temperature": { + "type": "number", + "description": "Temperature parameter used for generating the segment" + }, + "avg_logprob": { + "type": "number", + "description": "Average logprob of the segment" + }, + "compression_ratio": { + "type": "number", + "description": "Compression ratio of the segment" + }, + "no_speech_prob": { + "type": "number", + "description": "Probability of no speech in the segment" + } + } + } + }, + "words": { + "type": "array", + "description": "Individual words and their corresponding timestamps", + "items": { + "type": "object", + "properties": { + "word": { + "type": "string", + "description": "The text content of the word" + }, + "start": { + "type": "number", + "description": "Start time of the word in seconds" + }, + "end": { + "type": "number", + "description": "End time of the word in seconds" + } + } + } + } + } + } + }, + "responses": { + "BadRequest": { + "description": "Bad Request - Invalid request format or missing required fields", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 400, + "error": { + "type": "invalid_request_error", + "message": "Invalid request format" + } + } + } + } + }, + "Unauthorized": { + "description": "Unauthorized - Invalid or missing API key", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 401, + "error": { + "type": "authentication_error", + "message": "Invalid API key provided" + } + } + } + } + }, + "RateLimited": { + "description": "Rate Limited - Too many requests", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 429, + "error": { + "type": "rate_limit_error", + "message": "Rate limit exceeded" + } + } + } + } + }, + "InternalServerError": { + "description": "Internal Server Error - An unexpected error occurred", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 500, + "error": { + "type": "api_error", + "message": "Internal server error occurred" + } + } + } + } + } + } + }, + "tags": [ + { + "name": "Chat Completions", + "description": "Create chat completions using conversational messages" + }, + { + "name": "Text Completions", + "description": "Create text completions from prompts" + }, + { + "name": "Audio", + "description": "Speech synthesis and audio transcription" + }, + { + "name": "MCP Tools", + "description": "Execute MCP tools" + }, + { + "name": "MCP Management", + "description": "Manage MCP client configurations and connections" + }, + { + "name": "Provider Management", + "description": "Manage AI provider configurations" + }, + { + "name": "Configuration", + "description": "System configuration management" + }, + { + "name": "Logging", + "description": "Application logs and dropped requests" + }, + { + "name": "WebSocket", + "description": "Real-time WebSocket connections" + }, + { + "name": "Integration - OpenAI", + "description": "OpenAI-compatible endpoints" + }, + { + "name": "Integration - Anthropic", + "description": "Anthropic-compatible endpoints" + }, + { + "name": "Integration - Gemini", + "description": "Google Gemini-compatible endpoints" + }, + { + "name": "Integration - LiteLLM", + "description": "LiteLLM-compatible endpoints" + }, + { + "name": "Monitoring", + "description": "Monitoring and observability endpoint" + } + ] +} diff --git a/docs/architecture/README.mdx b/docs/architecture/README.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/core/concurrency.mdx b/docs/architecture/core/concurrency.mdx new file mode 100644 index 000000000..660390628 --- /dev/null +++ b/docs/architecture/core/concurrency.mdx @@ -0,0 +1,761 @@ +--- +title: "Concurrency" +description: "Deep dive into Bifrost's advanced concurrency architecture - worker pools, goroutine management, channel-based communication, and resource isolation patterns." +icon: "traffic-light" +--- + +## Concurrency Philosophy + +### **Core Principles** + +| Principle | Implementation | Benefit | +| ---------------------------------- | -------------------------------------- | -------------------------------------- | +| **Provider Isolation** | Independent worker pools per provider | Fault tolerance, no cascade failures | +| **Channel-Based Communication** | Go channels for all async operations | Type-safe, deadlock-free communication | +| **Resource Pooling** | Object pools with lifecycle management | Predictable memory usage, minimal GC | +| **Non-Blocking Operations** | Async processing throughout pipeline | Maximum concurrency, no blocking waits | +| **Backpressure Handling** | Configurable buffers and flow control | Graceful degradation under load | + +### **Threading Architecture Overview** + +```mermaid +graph TB + subgraph "Main Thread" + Main[Main Process
HTTP Server] + Router[Request Router
Goroutine] + PluginMgr[Plugin Manager
Goroutine] + end + + subgraph "Provider Worker Pools" + subgraph "OpenAI Pool" + OAI1[Worker 1
Goroutine] + OAI2[Worker 2
Goroutine] + OAIN[Worker N
Goroutine] + end + subgraph "Anthropic Pool" + ANT1[Worker 1
Goroutine] + ANT2[Worker 2
Goroutine] + ANTN[Worker N
Goroutine] + end + subgraph "Bedrock Pool" + BED1[Worker 1
Goroutine] + BED2[Worker 2
Goroutine] + BEDN[Worker N
Goroutine] + end + end + + subgraph "Memory Pools" + ChannelPool[Channel Pool
sync.Pool] + MessagePool[Message Pool
sync.Pool] + ResponsePool[Response Pool
sync.Pool] + end + + Main --> Router + Router --> PluginMgr + PluginMgr --> OAI1 + PluginMgr --> ANT1 + PluginMgr --> BED1 + + OAI1 --> ChannelPool + ANT1 --> MessagePool + BED1 --> ResponsePool +``` + +--- + +## Worker Pool Architecture + +### **Provider-Isolated Worker Pools** + +```mermaid +stateDiagram-v2 + [*] --> PoolInit: Worker Pool Creation + PoolInit --> WorkerSpawn: Spawn Worker Goroutines + WorkerSpawn --> Listening: Workers Listen on Channels + + Listening --> Processing: Job Received + Processing --> API_Call: Provider API Request + API_Call --> Response: Process Response + Response --> Listening: Job Complete + + Listening --> Shutdown: Graceful Shutdown + Processing --> Shutdown: Complete Current Job + Shutdown --> [*]: Pool Destroyed +``` + +**Worker Pool Architecture:** + +The worker pool system maintains a sophisticated balance between resource efficiency and performance isolation: + +**Key Components:** + +- **Worker Pool Management** - Pre-spawned workers reduce startup latency +- **Job Queue System** - Buffered channels provide smooth load balancing +- **Resource Pools** - HTTP clients and API keys are pooled for efficiency +- **Health Monitoring** - Circuit breakers detect and isolate failing providers +- **Graceful Shutdown** - Workers complete current jobs before terminating + +**Startup Process:** + +1. **Worker Pre-spawning** - Workers are created during pool initialization +2. **Channel Setup** - Job queues and worker channels are established +3. **Resource Allocation** - HTTP clients and API keys are distributed +4. **Health Checks** - Initial connectivity tests verify provider availability +5. **Ready State** - Pool becomes available for request processing + +**Job Dispatch Logic:** + +- **Round-Robin Assignment** - Jobs are distributed evenly across available workers +- **Load Balancing** - Worker availability determines job assignment +- **Overflow Handling** - Excess jobs are queued or dropped based on configuration + +### **Worker Lifecycle Management** + +```mermaid +sequenceDiagram + participant Pool + participant Worker + participant HTTPClient + participant Provider + participant Metrics + + Pool->>Worker: Start() + Worker->>Worker: Initialize HTTP Client + Worker->>Pool: Ready Signal + + loop Job Processing + Pool->>Worker: Job Assignment + Worker->>HTTPClient: Prepare Request + HTTPClient->>Provider: API Call + Provider-->>HTTPClient: Response + HTTPClient-->>Worker: Parsed Response + Worker->>Metrics: Record Performance + Worker->>Pool: Job Complete + end + + Pool->>Worker: Shutdown Signal + Worker->>Worker: Complete Current Job + Worker-->>Pool: Shutdown Confirmed +```` + +--- + +## Channel-Based Communication + +### **Channel Architecture** + +```mermaid +graph TB + subgraph "Channel Types" + JobQueue[Job Queue
Buffered Channel] + WorkerPool[Worker Pool
Buffered Channel] + ResultChan[Result Channel
Buffered Channel] + QuitChan[Quit Channel
Unbuffered] + end + + subgraph "Flow Control" + BackPressure[Backpressure
Buffer Limits] + Timeout[Timeout
Context Cancellation] + Graceful[Graceful Shutdown
Channel Closing] + end + + JobQueue --> BackPressure + WorkerPool --> Timeout + ResultChan --> Graceful +``` + +**Channel Configuration Principles:** + +Bifrost's channel system balances throughput and memory usage through careful buffer sizing: + +**Job Queuing Configuration:** + +- **Job Queue Buffer** - Sized based on expected burst traffic (100-1000 jobs) +- **Worker Pool Size** - Matches provider concurrency limits (10-100 workers) +- **Result Buffer** - Accommodates response processing delays (50-500 responses) + +**Flow Control Parameters:** + +- **Queue Wait Limits** - Maximum time jobs wait before timeout (1-10 seconds) +- **Processing Timeouts** - Per-job execution limits (30-300 seconds) +- **Shutdown Timeouts** - Graceful termination periods (5-30 seconds) + +**Backpressure Policies:** + +- **Drop Policy** - Discard excess jobs when queues are full +- **Block Policy** - Wait for queue space with timeout +- **Error Policy** - Immediately return error for full queues + +**Channel Type Selection:** + +- **Buffered Channels** - Used for async job processing and result handling +- **Unbuffered Channels** - Used for synchronization signals (quit, done) +- **Context Cancellation** - Used for timeout and cancellation propagation + +### **Backpressure and Flow Control** + +```mermaid +flowchart TD + Request[Incoming Request] --> QueueCheck{Queue Full?} + QueueCheck -->|No| Queue[Add to Queue] + QueueCheck -->|Yes| Policy{Drop Policy?} + + Policy -->|Drop| Drop[Drop Request
Return Error] + Policy -->|Block| Block[Block Until Space
With Timeout] + Policy -->|Error| Error[Return Queue Full Error] + + Queue --> Worker[Assign to Worker] + Block --> TimeoutCheck{Timeout?} + TimeoutCheck -->|Yes| Error + TimeoutCheck -->|No| Queue + + Worker --> Processing[Process Request] + Processing --> Complete[Complete] + + Drop --> Client[Client Response] + Error --> Client + Complete --> Client +```` + +**Backpressure Implementation Strategy:** + +The backpressure system protects Bifrost from being overwhelmed while maintaining service availability: + +**Non-Blocking Job Submission:** + +- **Immediate Queue Check** - Jobs are submitted without blocking on queue space +- **Success Path** - Available queue space allows immediate job acceptance +- **Overflow Detection** - Full queues trigger backpressure policies +- **Metrics Collection** - All queue operations are tracked for monitoring + +**Backpressure Policy Execution:** + +- **Drop Policy** - Immediately rejects excess jobs with meaningful error messages +- **Block Policy** - Waits for queue space with configurable timeout limits +- **Error Policy** - Returns queue full errors for immediate client feedback +- **Metrics Tracking** - Dropped, blocked, and successful submissions are measured + +**Timeout Management:** + +- **Context-Based Timeouts** - All blocking operations respect timeout boundaries +- **Graceful Degradation** - Timeouts result in controlled error responses +- **Resource Protection** - Prevents goroutine leaks from infinite waits + +```go + case pool.jobQueue <- job: + pool.metrics.IncQueuedJobs() + return nil + case <-ctx.Done(): + pool.metrics.IncTimeoutJobs() + return errors.New("queue full, timeout waiting") + } + + case "error": + pool.metrics.IncRejectedJobs() + return errors.New("queue full, job rejected") + + default: + return errors.New("unknown queue policy") + } + } + } +``` + +--- + +## Memory Pool Concurrency + +### **Thread-Safe Object Pools** + +```mermaid +graph TB + subgraph "sync.Pool Architecture" + GetObject[Get Object
sync.Pool.Get()] + NewObject[New Object
Factory Function] + UseObject[Use Object
Application Logic] + ResetObject[Reset Object
Clear State] + ReturnObject[Return Object
sync.Pool.Put()] + end + + subgraph "GC Integration" + GCRun[GC Runs] + PoolCleanup[Pool Cleanup
Automatic] + Reallocation[Object Reallocation
as Needed] + end + + GetObject --> NewObject + NewObject --> UseObject + UseObject --> ResetObject + ResetObject --> ReturnObject + ReturnObject --> GetObject + + GCRun --> PoolCleanup + PoolCleanup --> Reallocation +``` + +**Thread-Safe Pool Architecture:** + +Bifrost's memory pool system ensures thread-safe object reuse across multiple goroutines: + +**Pool Structure Design:** + +- **Multiple Pool Types** - Separate pools for channels, messages, responses, and buffers +- **Factory Functions** - Dynamic object creation when pools are empty +- **Statistics Tracking** - Comprehensive metrics for pool performance monitoring +- **Thread Safety** - Synchronized access using Go's sync.Pool and read-write mutexes + +**Object Lifecycle Management:** + +- **Pool Initialization** - Factory functions define object creation patterns +- **Unique Identification** - Each pooled object gets a unique ID for tracking +- **Timestamp Tracking** - Creation, acquisition, and return times are recorded +- **Reusability Flags** - Objects can be marked as non-reusable for single-use scenarios + +**Acquisition Strategy:** + +- **Request Tracking** - All pool requests are counted for monitoring +- **Hit/Miss Tracking** - Pool effectiveness is measured through hit ratios +- **Fallback Creation** - New objects are created when pools are empty +- **Performance Metrics** - Acquisition times and patterns are monitored + +**Return and Reset Process:** + +- **State Validation** - Only reusable objects are returned to pools +- **Object Reset** - All object state is cleared before returning to pool +- **Return Tracking** - Return operations are counted and timed +- **Pool Replenishment** - Returned objects become available for reuse + +### **Pool Performance Monitoring** + +Comprehensive metrics provide insights into pool efficiency and system health: + +**Usage Statistics Collection:** +- **Request Counting** - Track total pool requests by object type +- **Creation Tracking** - Monitor new object allocations when pools are empty +- **Hit/Miss Ratios** - Measure pool effectiveness through reuse rates +- **Return Monitoring** - Track successful object returns to pools + +**Performance Metrics Analysis:** +- **Acquisition Times** - Measure how long it takes to get objects from pools +- **Reset Performance** - Track time spent cleaning objects for reuse +- **Hit Ratio Calculation** - Determine percentage of requests served from pools +- **Memory Efficiency** - Calculate memory savings from object reuse + +**Key Performance Indicators:** +- **Channel Pool Hit Ratio** - Typically 85-95% in steady state +- **Message Pool Efficiency** - Usually 80-90% reuse rate +- **Response Pool Utilization** - Often 70-85% hit ratio +- **Total Memory Savings** - Measured reduction in garbage collection pressure + +**Monitoring Integration:** +- **Thread-Safe Access** - All metrics collection is synchronized +- **Real-Time Updates** - Statistics are updated with each pool operation +- **Export Capability** - Metrics are available in JSON format for monitoring systems +- **Alerting Support** - Low hit ratios can trigger performance alerts + +--- + +## Goroutine Management + +### **Goroutine Lifecycle Patterns** + +```mermaid +stateDiagram-v2 + [*] --> Created: go routine() + Created --> Running: Execute Function + Running --> Waiting: Channel/Mutex Block + Waiting --> Running: Unblocked + Running --> Syscall: Network I/O + Syscall --> Running: I/O Complete + Running --> GCAssist: GC Triggered + GCAssist --> Running: GC Complete + Running --> Terminated: Function Exit + Terminated --> [*]: Cleanup +``` + +**Goroutine Pool Management Strategy:** + +Bifrost's goroutine management ensures optimal resource usage while preventing goroutine leaks: + +**Pool Configuration Management:** + +- **Goroutine Limits** - Maximum concurrent goroutines prevent resource exhaustion +- **Active Counting** - Atomic counters track currently running goroutines +- **Idle Timeouts** - Unused goroutines are cleaned up after configured periods +- **Resource Boundaries** - Hard limits prevent runaway goroutine creation + +**Lifecycle Orchestration:** + +- **Spawn Channels** - New goroutine creation is tracked through channels +- **Completion Monitoring** - Finished goroutines signal completion for cleanup +- **Shutdown Coordination** - Graceful shutdown ensures all goroutines complete properly +- **Health Monitoring** - Continuous monitoring tracks goroutine health and performance + +**Worker Creation Process:** + +- **Limit Enforcement** - Creation fails when maximum goroutine count is reached +- **Unique Identification** - Each goroutine gets a unique ID for tracking and debugging +- **Lifecycle Tracking** - Start times and names enable performance analysis +- **Atomic Operations** - Thread-safe counters prevent race conditions + +**Panic Recovery and Error Handling:** + +- **Panic Isolation** - Goroutine panics don't crash the entire system +- **Error Logging** - Panic details are logged with goroutine context +- **Metrics Updates** - Panic counts are tracked for monitoring and alerting +- **Resource Cleanup** - Failed goroutines are properly cleaned up and counted + +**Health Monitoring System:** + +- **Periodic Health Checks** - Regular intervals check goroutine pool health +- **Completion Tracking** - Finished goroutines are recorded for performance analysis +- **Shutdown Handling** - Clean shutdown process ensures no goroutine leaks + +### **Resource Leak Prevention** + +```mermaid +flowchart TD + GoroutineStart[Goroutine Start] --> ResourceCheck[Resource Allocation Check] + ResourceCheck --> Timeout[Set Timeout Context] + Timeout --> Work[Execute Work] + + Work --> Complete{Work Complete?} + Complete -->|Yes| Cleanup[Cleanup Resources] + Complete -->|No| TimeoutCheck{Timeout?} + + TimeoutCheck -->|Yes| ForceCleanup[Force Cleanup] + TimeoutCheck -->|No| Work + + Cleanup --> Return[Return Resources to Pool] + ForceCleanup --> Return + Return --> End[Goroutine End] +```` + +**Resource Leak Prevention:** + +```go +func (worker *Worker) ExecuteWithCleanup(job *Job) { + // Set timeout context + ctx, cancel := context.WithTimeout( + context.Background(), + worker.config.ProcessTimeout, + ) + defer cancel() + + // Acquire resources with timeout + resources, err := worker.acquireResources(ctx) + if err != nil { + job.resultChan <- &Result{Error: err} + return + } + + // Ensure cleanup happens + defer func() { + // Always return resources + worker.returnResources(resources) + + // Handle panics + if r := recover(); r != nil { + worker.metrics.IncPanics() + job.resultChan <- &Result{ + Error: fmt.Errorf("worker panic: %v", r), + } + } + }() + + // Execute job with context + result := worker.processJob(ctx, job, resources) + + // Return result + select { + case job.resultChan <- result: + // Success + case <-ctx.Done(): + // Timeout - result channel might be closed + worker.metrics.IncTimeouts() + } +} +``` + +--- + +## Concurrency Optimization Strategies + +### **Load-Based Worker Scaling** (Planned) + +```mermaid +graph TB + subgraph "Load Monitoring" + QueueDepth[Queue Depth
Monitoring] + ResponseTime[Response Time
Tracking] + WorkerUtil[Worker Utilization
Metrics] + end + + subgraph "Scaling Decisions" + ScaleUp{Scale Up?
Load > 80%} + ScaleDown{Scale Down?
Load < 30%} + Maintain[Maintain
Current Size] + end + + subgraph "Actions" + AddWorkers[Spawn Additional
Workers] + RemoveWorkers[Graceful Worker
Shutdown] + NoAction[No Action
Monitor Continue] + end + + QueueDepth --> ScaleUp + ResponseTime --> ScaleUp + WorkerUtil --> ScaleDown + + ScaleUp -->|Yes| AddWorkers + ScaleUp -->|No| ScaleDown + ScaleDown -->|Yes| RemoveWorkers + ScaleDown -->|No| Maintain + + Maintain --> NoAction +``` + +**Adaptive Scaling Implementation:** + +```go +type AdaptiveScaler struct { + pool *ProviderWorkerPool + config ScalingConfig + metrics *ScalingMetrics + lastScaleTime time.Time + scalingMutex sync.Mutex +} + +func (scaler *AdaptiveScaler) EvaluateScaling() { + scaler.scalingMutex.Lock() + defer scaler.scalingMutex.Unlock() + + // Prevent frequent scaling + if time.Since(scaler.lastScaleTime) < scaler.config.MinScaleInterval { + return + } + + current := scaler.getCurrentMetrics() + + // Scale up conditions + if current.QueueUtilization > scaler.config.ScaleUpThreshold || + current.AvgResponseTime > scaler.config.MaxResponseTime { + + scaler.scaleUp(current) + return + } + + // Scale down conditions + if current.QueueUtilization < scaler.config.ScaleDownThreshold && + current.AvgResponseTime < scaler.config.TargetResponseTime { + + scaler.scaleDown(current) + return + } +} + +func (scaler *AdaptiveScaler) scaleUp(metrics *CurrentMetrics) { + currentWorkers := scaler.pool.GetWorkerCount() + targetWorkers := int(float64(currentWorkers) * scaler.config.ScaleUpFactor) + + // Respect maximum limits + if targetWorkers > scaler.config.MaxWorkers { + targetWorkers = scaler.config.MaxWorkers + } + + additionalWorkers := targetWorkers - currentWorkers + if additionalWorkers > 0 { + scaler.pool.AddWorkers(additionalWorkers) + scaler.lastScaleTime = time.Now() + scaler.metrics.RecordScaleUp(additionalWorkers) + } +} +``` + +### **Provider-Specific Optimization** + +```go +type ProviderOptimization struct { + // Provider characteristics + ProviderName string `json:"provider_name"` + RateLimit int `json:"rate_limit"` // Requests per second + AvgLatency time.Duration `json:"avg_latency"` // Average response time + ErrorRate float64 `json:"error_rate"` // Historical error rate + + // Optimal configuration + OptimalWorkers int `json:"optimal_workers"` + OptimalBuffer int `json:"optimal_buffer"` + TimeoutConfig time.Duration `json:"timeout_config"` + RetryStrategy RetryConfig `json:"retry_strategy"` +} + +func CalculateOptimalConcurrency(provider ProviderOptimization) ConcurrencyConfig { + // Calculate based on rate limits and latency + optimalWorkers := provider.RateLimit * int(provider.AvgLatency.Seconds()) + + // Adjust for error rate (more workers for higher error rate) + errorAdjustment := 1.0 + provider.ErrorRate + optimalWorkers = int(float64(optimalWorkers) * errorAdjustment) + + // Buffer should be 2-3x worker count for smooth operation + optimalBuffer := optimalWorkers * 3 + + return ConcurrencyConfig{ + Concurrency: optimalWorkers, + BufferSize: optimalBuffer, + Timeout: provider.AvgLatency * 2, // 2x avg latency for timeout + } +} +``` + +--- + +## Concurrency Monitoring & Metrics + +### **Key Concurrency Metrics** + +```mermaid +graph TB + subgraph "Worker Metrics" + ActiveWorkers[Active Workers
Current Count] + IdleWorkers[Idle Workers
Available Count] + BusyWorkers[Busy Workers
Processing Count] + end + + subgraph "Queue Metrics" + QueueDepth[Queue Depth
Pending Jobs] + QueueThroughput[Queue Throughput
Jobs/Second] + QueueWaitTime[Queue Wait Time
Average Delay] + end + + subgraph "Performance Metrics" + GoroutineCount[Goroutine Count
Total Active] + MemoryUsage[Memory Usage
Pool Utilization] + GCPressure[GC Pressure
Collection Frequency] + end + + subgraph "Health Metrics" + ErrorRate[Error Rate
Failed Jobs %] + PanicCount[Panic Count
Crashed Goroutines] + DeadlockDetection[Deadlock Detection
Blocked Operations] + end +``` + +**Metrics Collection Strategy:** + +Comprehensive concurrency monitoring provides operational insights and performance optimization data: + +**Worker Pool Monitoring:** + +- **Total Worker Tracking** - Monitor configured vs actual worker counts +- **Active Worker Monitoring** - Track workers currently processing requests +- **Idle Worker Analysis** - Identify unused capacity and optimization opportunities +- **Queue Depth Monitoring** - Track pending job backlog and processing delays + +**Performance Data Collection:** + +- **Throughput Metrics** - Measure jobs processed per second across all pools +- **Wait Time Analysis** - Track how long jobs wait in queues before processing +- **Memory Pool Performance** - Monitor hit/miss ratios for memory pool effectiveness +- **Goroutine Count Tracking** - Ensure goroutine counts remain within healthy limits + +**Health and Reliability Metrics:** + +- **Panic Recovery Tracking** - Count and analyze worker panic occurrences +- **Timeout Monitoring** - Track jobs that exceed processing time limits +- **Circuit Breaker Events** - Monitor provider isolation events and recoveries +- **Error Rate Analysis** - Track failure patterns for capacity planning + +**Real-Time Updates:** + +- **Live Metric Updates** - Worker metrics are updated continuously during operation +- **Processing Event Recording** - Each job completion updates relevant metrics +- **Performance Correlation** - Queue times and processing times are correlated for analysis +- **Success/Failure Tracking** - All job outcomes are recorded for reliability analysis + +--- + +## Deadlock Prevention & Detection + +### **Deadlock Prevention Strategies** + +```mermaid +flowchart TD + Strategy1[Lock Ordering
Consistent Acquisition] + Strategy2[Timeout-Based Locks
Context Cancellation] + Strategy3[Channel Select
Non-blocking Operations] + Strategy4[Resource Hierarchy
Layered Locking] + + Prevention[Deadlock Prevention
Design Patterns] + + Prevention --> Strategy1 + Prevention --> Strategy2 + Prevention --> Strategy3 + Prevention --> Strategy4 + + Strategy1 --> Success[No Deadlocks
Guaranteed Order] + Strategy2 --> Success + Strategy3 --> Success + Strategy4 --> Success +```` + +**Deadlock Prevention Implementation Strategy:** + +Bifrost employs multiple complementary strategies to prevent deadlocks in concurrent operations: + +**Lock Ordering Management:** + +- **Consistent Acquisition Order** - All locks are acquired in a predetermined order +- **Global Lock Registry** - Centralized registry maintains lock ordering relationships +- **Order Enforcement** - Lock acquisition automatically sorts by predetermined order +- **Dependency Tracking** - Lock dependencies are mapped to prevent circular waits + +**Timeout-Based Protection:** + +- **Default Timeouts** - All lock acquisitions have reasonable timeout limits +- **Context Cancellation** - Operations respect context cancellation for cleanup +- **Maximum Timeout Limits** - Upper bounds prevent indefinite blocking +- **Graceful Timeout Handling** - Timeout errors provide meaningful context + +**Multi-Lock Acquisition Process:** + +- **Ordered Sorting** - Multiple locks are sorted before acquisition attempts +- **Progressive Acquisition** - Locks are acquired one by one in sorted order +- **Failure Recovery** - Failed acquisitions trigger automatic cleanup of held locks +- **Resource Tracking** - All acquired locks are tracked for proper release + +**Lock Acquisition Safety:** + +- **Non-Blocking Detection** - Channel-based lock attempts prevent indefinite blocking +- **Timeout Enforcement** - All lock attempts respect configured timeout limits +- **Error Propagation** - Lock failures are properly propagated with context +- **Cleanup Guarantees** - Failed operations always clean up partially acquired resources + +**Deadlock Detection and Recovery:** + +- **Active Monitoring** - Continuous monitoring for potential deadlock conditions +- **Automatic Recovery** - Detected deadlocks trigger automatic resolution procedures +- **Resource Release** - Deadlock resolution involves strategic resource release +- **Prevention Learning** - Deadlock patterns inform prevention strategy improvements + +--- + +## Related Architecture Documentation + +- **[Request Flow](./request-flow)** - How concurrency fits in request processing +- **[Benchmarks](../../benchmarking/getting-started)** - Concurrency performance characteristics +- **[Plugin System](./plugins)** - Plugin concurrency considerations +- **[MCP System](./mcp)** - MCP concurrency and worker integration + +## Usage Documentation + +- **[Provider Configuration](../../quickstart/gateway/provider-configuration)** - Configure concurrency settings per provider +- **[Performance Analysis](../../benchmarking/getting-started)** - Memory pool configuration and optimization +- **[Performance Monitoring](../../features/telemetry)** - Monitor concurrency metrics and health +- **[Go SDK Usage](../../quickstart/go-sdk/setting-up)** - Use Bifrost concurrency in Go applications +- **[Gateway Setup](../../quickstart/gateway/setting-up)** - Deploy Bifrost with optimal concurrency settings + +--- + +**🎯 Next Step:** Understand how plugins integrate with the concurrency model in **[Plugin System](./plugins)**. +``` diff --git a/docs/architecture/core/mcp.mdx b/docs/architecture/core/mcp.mdx new file mode 100644 index 000000000..f716e369a --- /dev/null +++ b/docs/architecture/core/mcp.mdx @@ -0,0 +1,564 @@ +--- +title: "Model Context Protocol (MCP)" +description: "Deep dive into Bifrost's Model Context Protocol (MCP) integration - how external tool discovery, execution, and integration work internally." +icon: "toolbox" +--- + +## MCP Architecture Overview + +### **What is MCP in Bifrost?** + +The Model Context Protocol (MCP) system in Bifrost enables AI models to seamlessly discover and execute external tools, transforming static chat models into dynamic, action-capable agents. This architecture bridges the gap between AI reasoning and real-world tool execution. + +**Core MCP Principles:** + +- **Dynamic Discovery** - Tools are discovered at runtime, not hardcoded +- **Client-Side Execution** - Bifrost controls all tool execution for security +- **Multi-Protocol Support** - STDIO, HTTP, and SSE connection types +- **Request-Level Filtering** - Granular control over tool availability +- **Async Execution** - Non-blocking tool invocation and response handling + +### **MCP System Components** + +```mermaid +graph TB + subgraph "MCP Management Layer" + MCPMgr[MCP Manager
Central Controller] + ClientRegistry[Client Registry
Connection Management] + ToolDiscovery[Tool Discovery
Runtime Registration] + end + + subgraph "MCP Execution Layer" + ToolFilter[Tool Filter
Access Control] + ToolExecutor[Tool Executor
Invocation Engine] + ResultProcessor[Result Processor
Response Handling] + end + + subgraph "Connection Types" + STDIOConn[STDIO Connections
Command-line Tools] + HTTPConn[HTTP Connections
Web Services] + SSEConn[SSE Connections
Real-time Streams] + end + + subgraph "External MCP Servers" + FileSystem[Filesystem Tools
File Operations] + WebSearch[Web Search
Information Retrieval] + Database[Database Tools
Data Access] + Custom[Custom Tools
Business Logic] + end + + MCPMgr --> ClientRegistry + ClientRegistry --> ToolDiscovery + ToolDiscovery --> ToolFilter + ToolFilter --> ToolExecutor + ToolExecutor --> ResultProcessor + + ClientRegistry --> STDIOConn + ClientRegistry --> HTTPConn + ClientRegistry --> SSEConn + + STDIOConn --> FileSystem + HTTPConn --> WebSearch + HTTPConn --> Database + STDIOConn --> Custom +``` + +--- + +## MCP Connection Architecture + +### **Multi-Protocol Connection System** + +Bifrost supports four MCP connection types, each optimized for different tool deployment patterns: + +```mermaid +graph TB + subgraph "InProcess Connections" + InProcess[In-Memory Tools
Same Process] + InProcessEx[Examples:
β€’ Embedded tools
β€’ High-perf operations
β€’ Testing tools] + end + + subgraph "STDIO Connections" + STDIO[Command Line Tools
Local Execution] + STDIOEx[Examples:
β€’ Filesystem tools
β€’ Local scripts
β€’ CLI utilities] + end + + subgraph "HTTP Connections" + HTTP[Web Service Tools
Remote APIs] + HTTPEx[Examples:
β€’ Web search APIs
β€’ Database services
β€’ External integrations] + end + + subgraph "SSE Connections" + SSE[Real-time Tools
Streaming Data] + SSEEx[Examples:
β€’ Live data feeds
β€’ Real-time monitoring
β€’ Event streams] + end + + subgraph "Connection Characteristics" + Latency[Latency:
InProcess < STDIO < HTTP < SSE] + Security[Security:
InProcess/Local > HTTP > SSE] + Scalability[Scalability:
HTTP > SSE > STDIO > InProcess] + Complexity[Complexity:
InProcess < STDIO < HTTP < SSE] + end + + InProcess --> Latency + STDIO --> Latency + HTTP --> Security + SSE --> Scalability + HTTP --> Complexity +``` + +### **Connection Type Details** + +**InProcess Connections (In-Memory Tools):** + +- **Use Case:** Embedded tools, high-performance operations, testing +- **Performance:** Lowest possible latency (~0.1ms) with no IPC overhead +- **Security:** Highest security as tools run in the same process +- **Limitations:** Go package only, cannot be configured via JSON + +**STDIO Connections (Local Tools):** + +- **Use Case:** Command-line tools, local scripts, filesystem operations +- **Performance:** Low latency (~1-10ms) due to local execution +- **Security:** High security with full local control +- **Limitations:** Single-server deployment, resource sharing + +**HTTP Connections (Remote Services):** + +- **Use Case:** Web APIs, microservices, cloud functions +- **Performance:** Network-dependent latency (~10-500ms) +- **Security:** Configurable with authentication and encryption +- **Advantages:** Scalable, multi-server deployment, service isolation + +**SSE Connections (Streaming Tools):** + +- **Use Case:** Real-time data feeds, live monitoring, event streams +- **Performance:** Variable latency depending on stream frequency +- **Security:** Similar to HTTP with streaming capabilities +- **Benefits:** Real-time updates, persistent connections, event-driven + +> **MCP Configuration:** [MCP Setup Guide β†’](../../features/mcp) + +--- + +## Tool Discovery & Registration + +### **Dynamic Tool Discovery Process** + +The MCP system discovers tools at runtime rather than requiring static configuration, enabling flexible and adaptive tool availability: + +```mermaid +sequenceDiagram + participant Bifrost + participant MCPManager + participant MCPServer + participant ToolRegistry + participant AIModel + + Note over Bifrost: System Startup + Bifrost->>MCPManager: Initialize MCP System + MCPManager->>MCPServer: Establish Connection + MCPServer-->>MCPManager: Connection Ready + + MCPManager->>MCPServer: List Available Tools + MCPServer-->>MCPManager: Tool Definitions + MCPManager->>ToolRegistry: Register Tools + + Note over Bifrost: Runtime Request Processing + AIModel->>MCPManager: Request Available Tools + MCPManager->>ToolRegistry: Query Tools + ToolRegistry-->>MCPManager: Filtered Tool List + MCPManager-->>AIModel: Available Tools + + AIModel->>MCPManager: Execute Tool Call + MCPManager->>MCPServer: Tool Invocation + MCPServer->>MCPServer: Execute Tool Logic + MCPServer-->>MCPManager: Tool Result + MCPManager-->>AIModel: Enhanced Response +``` + +### **Tool Registry Management** + +**Registration Process:** + +1. **Connection Establishment** - MCP client connects to configured servers +2. **Capability Exchange** - Server announces available tools and schemas +3. **Tool Validation** - Bifrost validates tool definitions and security +4. **Registry Update** - Tools are registered in the internal tool registry +5. **Availability Notification** - Tools become available for AI model use + +**Registry Features:** + +- **Dynamic Updates** - Tools can be added/removed during runtime +- **Version Management** - Support for tool versioning and compatibility +- **Access Control** - Request-level tool filtering and permissions +- **Health Monitoring** - Continuous tool availability checking + +**Tool Metadata Structure:** + +- **Name & Description** - Human-readable tool identification +- **Parameters Schema** - JSON schema for tool input validation +- **Return Schema** - Expected response format definition +- **Capabilities** - Tool feature flags and limitations +- **Authentication** - Required credentials and permissions + +--- + +## Tool Filtering & Access Control + +### **Multi-Level Filtering System** + +Bifrost provides granular control over tool availability through a sophisticated filtering system: + +```mermaid +flowchart TD + Request[Incoming Request] --> GlobalFilter{Global MCP Filter} + GlobalFilter -->|Enabled| ClientFilter[MCP Client Filtering] + GlobalFilter -->|Disabled| NoMCP[No MCP Tools] + + ClientFilter --> IncludeClients{Include Clients?} + IncludeClients -->|Yes| IncludeList[Include Specified
MCP Clients] + IncludeClients -->|No| AllClients[All MCP Clients] + + IncludeList --> ExcludeClients{Exclude Clients?} + AllClients --> ExcludeClients + ExcludeClients -->|Yes| RemoveClients[Remove Excluded
MCP Clients] + ExcludeClients -->|No| ClientsFiltered[Filtered Clients] + + RemoveClients --> ToolFilter[Tool-Level Filtering] + ClientsFiltered --> ToolFilter + + ToolFilter --> IncludeTools{Include Tools?} + IncludeTools -->|Yes| IncludeSpecific[Include Specified
Tools Only] + IncludeTools -->|No| AllTools[All Available Tools] + + IncludeSpecific --> ExcludeTools{Exclude Tools?} + AllTools --> ExcludeTools + ExcludeTools -->|Yes| RemoveTools[Remove Excluded
Tools] + ExcludeTools -->|No| FinalTools[Final Tool Set] + + RemoveTools --> FinalTools + FinalTools --> AIModel[Available to AI Model] + NoMCP --> AIModel +``` + +### **Filtering Configuration Levels** + +**Request-Level Filtering:** + +```bash +# Include only specific MCP clients +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "mcp-include-clients: filesystem,websearch" \ + -d '{"model": "gpt-4o-mini", "messages": [...]}' + +# Exclude dangerous tools +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "mcp-exclude-tools: delete_file,format_disk" \ + -d '{"model": "gpt-4o-mini", "messages": [...]}' +``` + +**Configuration-Level Filtering:** + +- **Client Selection** - Choose which MCP servers to connect to +- **Tool Blacklisting** - Permanently disable dangerous or unwanted tools +- **Permission Mapping** - Map user roles to available tool sets +- **Environment-Based** - Different tool sets for development vs production + +**Security Benefits:** + +- **Principle of Least Privilege** - Only necessary tools are exposed +- **Dynamic Access Control** - Per-request tool availability +- **Audit Trail** - Track which tools are used by which requests +- **Risk Mitigation** - Prevent access to dangerous operations + +> **πŸ“– Tool Filtering:** [MCP Tool Control β†’](../../features/mcp) + +--- + +## Tool Execution Engine + +### **Async Tool Execution Architecture** + +The MCP execution engine handles tool invocation asynchronously to maintain system responsiveness and enable complex multi-tool workflows: + +```mermaid +sequenceDiagram + participant AIModel + participant ExecutionEngine + participant ToolInvoker + participant MCPServer + participant ResultProcessor + + AIModel->>ExecutionEngine: Tool Call Request + ExecutionEngine->>ExecutionEngine: Validate Tool Call + ExecutionEngine->>ToolInvoker: Queue Tool Execution + + Note over ToolInvoker: Async Tool Execution + ToolInvoker->>MCPServer: Invoke Tool + MCPServer->>MCPServer: Execute Tool Logic + MCPServer-->>ToolInvoker: Raw Tool Result + + ToolInvoker->>ResultProcessor: Process Result + ResultProcessor->>ResultProcessor: Format & Validate + ResultProcessor-->>ExecutionEngine: Processed Result + + ExecutionEngine-->>AIModel: Tool Execution Complete + + Note over AIModel: Multi-turn Conversation + AIModel->>ExecutionEngine: Continue with Tool Results + ExecutionEngine->>ExecutionEngine: Merge Results into Context + ExecutionEngine-->>AIModel: Enhanced Response +``` + +### **Execution Flow Characteristics** + +**Validation Phase:** + +- **Parameter Validation** - Ensure tool arguments match expected schema +- **Permission Checking** - Verify tool access permissions for the request +- **Rate Limiting** - Apply per-tool and per-user rate limits +- **Security Scanning** - Check for potentially dangerous operations + +**Execution Phase:** + +- **Timeout Management** - Bounded execution time to prevent hanging +- **Error Handling** - Graceful handling of tool failures and timeouts +- **Result Streaming** - Support for tools that return streaming responses +- **Resource Monitoring** - Track tool resource usage and performance + +**Response Phase:** + +- **Result Formatting** - Convert tool outputs to consistent format +- **Error Enrichment** - Add context and suggestions for tool failures +- **Multi-Result Aggregation** - Combine multiple tool outputs coherently +- **Context Integration** - Merge tool results into conversation context + +### **Multi-Turn Conversation Support** + +The MCP system enables sophisticated multi-turn conversations where AI models can: + +1. **Initial Tool Discovery** - Request available tools for a given context +2. **Tool Execution** - Execute one or more tools based on user request +3. **Result Analysis** - Analyze tool outputs and determine next steps +4. **Follow-up Actions** - Execute additional tools based on previous results +5. **Response Synthesis** - Combine tool results into coherent user response + +**Example Multi-Turn Flow:** + +``` +User: "Find recent news about AI and save interesting articles" +AI: β†’ Execute web_search("AI news recent") +AI: β†’ Analyze search results +AI: β†’ Execute save_article() for each interesting result +AI: β†’ Respond with summary of saved articles +``` + +### **Complete User-Controlled Tool Execution Flow** + +The following diagram shows the end-to-end user experience with MCP tool execution, highlighting the critical user control points and decision-making process: + +```mermaid +flowchart TD + A["πŸ‘€ User Message
\"List files in current directory\""] --> B["πŸ€– Bifrost Core"] + + B --> C["πŸ”§ MCP Manager
Auto-discovers and adds
available tools to request"] + + C --> D["🌐 LLM Provider
(OpenAI, Anthropic, etc.)"] + + D --> E{"πŸ” Response contains
tool_calls?"} + + E -->|No| F["βœ… Final Response
Display to user"] + + E -->|Yes| G["πŸ“ Add assistant message
with tool_calls to history"] + + G --> H["πŸ›‘οΈ YOUR EXECUTION LOGIC
(Security, Approval, Logging)"] + + H --> I{"πŸ€” User Decision Point
Execute this tool?"} + + I -->|Deny| J["❌ Create denial result
Add to conversation history"] + + I -->|Approve| K["βš™οΈ client.ExecuteMCPTool()
Bifrost executes via MCP"] + + K --> L["πŸ“Š Tool Result
Add to conversation history"] + + J --> M["πŸ”„ Continue conversation loop
Send updated history back to LLM"] + L --> M + + M --> D + + style A fill:#e1f5fe + style F fill:#e8f5e8 + style H fill:#fff3e0 + style I fill:#fce4ec + style K fill:#f3e5f5 +``` + +**Key Flow Characteristics:** + +**User Control Points:** + +- **Security Layer** - Your application controls all tool execution decisions +- **Approval Gate** - Users can approve or deny each tool execution +- **Transparency** - Full visibility into what tools will be executed and why +- **Conversation Continuity** - Tool results seamlessly integrate into conversation flow + +**Security Benefits:** + +- **No Automatic Execution** - Tools never execute without explicit approval +- **Audit Trail** - Complete logging of all tool execution decisions +- **Contextual Security** - Approval decisions can consider full conversation context +- **Graceful Denials** - Denied tools result in informative responses, not errors + +**Implementation Patterns:** + +```go +// Example tool execution control in your application +func handleToolExecution(toolCall schemas.ToolCall, userContext UserContext) error { + // YOUR SECURITY AND APPROVAL LOGIC HERE + if !userContext.HasPermission(toolCall.Function.Name) { + return createDenialResponse("Tool not permitted for user role") + } + + if requiresApproval(toolCall) { + approved := promptUserForApproval(toolCall) + if !approved { + return createDenialResponse("User denied tool execution") + } + } + + // Execute the tool via Bifrost + result, err := client.ExecuteMCPTool(ctx, toolCall) + if err != nil { + return handleToolError(err) + } + + return addToolResultToHistory(result) +} +``` + +This flow ensures that while AI models can discover and request tool usage, all actual execution remains under user control, providing the perfect balance of AI capability and human oversight. + +--- + +## MCP Integration Patterns + +### **Common Integration Scenarios** + +**1. Filesystem Operations** + +- **Tools:** `list_files`, `read_file`, `write_file`, `create_directory` +- **Use Cases:** Code analysis, document processing, file management +- **Security:** Sandboxed file access, path validation, permission checks +- **Performance:** Local execution for fast file operations + +**2. Web Search & Information Retrieval** + +- **Tools:** `web_search`, `fetch_url`, `extract_content`, `summarize` +- **Use Cases:** Research assistance, fact-checking, content gathering +- **Integration:** External search APIs, content parsing services +- **Caching:** Response caching for repeated queries + +**3. Database Operations** + +- **Tools:** `query_database`, `insert_record`, `update_record`, `schema_info` +- **Use Cases:** Data analysis, report generation, database administration +- **Security:** Read-only access by default, query validation, injection prevention +- **Performance:** Connection pooling, query optimization + +**4. API Integrations** + +- **Tools:** Custom business logic tools, third-party service integration +- **Use Cases:** CRM operations, payment processing, notification sending +- **Authentication:** API key management, OAuth token handling +- **Error Handling:** Retry logic, fallback mechanisms + +### **MCP Server Development Patterns** + +**Simple STDIO Server:** + +- **Language:** Any language that can read/write JSON to stdin/stdout +- **Deployment:** Single executable, minimal dependencies +- **Use Case:** Local tools, development utilities, simple scripts + +**HTTP Service Server:** + +- **Architecture:** RESTful API with MCP protocol endpoints +- **Scalability:** Horizontal scaling, load balancing +- **Use Case:** Shared tools, enterprise integrations, cloud services + +**Hybrid Approach:** + +- **Local + Remote:** Combine STDIO tools for local operations with HTTP for remote services +- **Failover:** Use local fallbacks when remote services are unavailable +- **Optimization:** Route tool calls to most appropriate execution environment + +> **πŸ“– MCP Development:** [Tool Development Guide β†’](../../features/mcp) + +--- + +## Security & Safety Considerations + +### **MCP Security Architecture** + +```mermaid +graph TB + subgraph "Security Layers" + L1[Connection Security
Authentication & Encryption] + L2[Tool Validation
Schema & Permission Checks] + L3[Execution Security
Sandboxing & Limits] + L4[Result Security
Output Validation & Filtering] + end + + subgraph "Threat Mitigation" + T1[Malicious Tools
Code Injection Prevention] + T2[Resource Abuse
Rate Limiting & Quotas] + T3[Data Exposure
Output Sanitization] + T4[System Access
Privilege Isolation] + end + + L1 --> T1 + L2 --> T2 + L3 --> T4 + L4 --> T3 +``` + +**Security Measures:** + +**Connection Security:** + +- **Authentication** - API keys, certificates, or token-based auth for HTTP/SSE +- **Encryption** - TLS for HTTP connections, secure pipes for STDIO +- **Network Isolation** - Firewall rules and network segmentation + +**Execution Security:** + +- **Sandboxing** - Isolated execution environments for tools +- **Resource Limits** - CPU, memory, and time constraints +- **Permission Model** - Principle of least privilege for tool access + +**Data Security:** + +- **Input Validation** - Strict parameter validation before tool execution +- **Output Sanitization** - Remove sensitive data from tool responses +- **Audit Logging** - Complete audit trail of tool usage + +**Operational Security:** + +- **Regular Updates** - Keep MCP servers and tools updated +- **Monitoring** - Continuous security monitoring and alerting +- **Incident Response** - Procedures for security incidents involving tools + +> **πŸ“– MCP Security:** [Security Best Practices β†’](../../features/mcp) + +--- + +## Related Architecture Documentation + +- **[Request Flow](./request-flow)** - MCP integration in request processing +- **[Concurrency Model](./concurrency)** - MCP concurrency and worker integration +- **[Plugin System](./plugins)** - Integration between MCP and plugin systems +- **[Benchmarks](../../benchmarking/getting-started)** - MCP performance impact and optimization + + + diff --git a/docs/architecture/core/plugins.mdx b/docs/architecture/core/plugins.mdx new file mode 100644 index 000000000..f4d9483a7 --- /dev/null +++ b/docs/architecture/core/plugins.mdx @@ -0,0 +1,590 @@ +--- +title: "Plugins" +description: "Deep dive into Bifrost's extensible plugin architecture - how plugins work internally, lifecycle management, execution model, and integration patterns." +icon: "puzzle-piece" +--- + +## Plugin Architecture Philosophy + +### **Core Design Principles** + +Bifrost's plugin system is built around five key principles that ensure extensibility without compromising performance or reliability: + +| Principle | Implementation | Benefit | +| ----------------------------- | ------------------------------------------------ | ------------------------------------------------ | +| **Plugin-First Design** | Core logic designed around plugin hook points | Maximum extensibility without core modifications | +| **Zero-Copy Integration** | Direct memory access to request/response objects | Minimal performance overhead | +| **Lifecycle Management** | Complete plugin lifecycle with automatic cleanup | Resource safety and leak prevention | +| **Interface-Based Safety** | Well-defined interfaces for type safety | Compile-time validation and consistency | +| **Failure Isolation** | Plugin errors don't crash the core system | Fault tolerance and system stability | + +### **Plugin System Overview** + +```mermaid +graph TB + subgraph "Plugin Management Layer" + PluginMgr[Plugin Manager
Central Controller] + Registry[Plugin Registry
Discovery & Loading] + Lifecycle[Lifecycle Manager
State Management] + end + + subgraph "Plugin Execution Layer" + Pipeline[Plugin Pipeline
Execution Orchestrator] + PreHooks[Pre-Processing Hooks
Request Modification] + PostHooks[Post-Processing Hooks
Response Enhancement] + end + + subgraph "Plugin Categories" + Auth[Authentication
& Authorization] + RateLimit[Rate Limiting
& Throttling] + Transform[Data Transformation
& Validation] + Monitor[Monitoring
& Analytics] + Custom[Custom Business
Logic] + end + + PluginMgr --> Registry + Registry --> Lifecycle + Lifecycle --> Pipeline + + Pipeline --> PreHooks + Pipeline --> PostHooks + + PreHooks --> Auth + PreHooks --> RateLimit + PostHooks --> Transform + PostHooks --> Monitor + PostHooks --> Custom +``` + +--- + +## Plugin Lifecycle Management + +### **Complete Lifecycle States** + +Every plugin goes through a well-defined lifecycle that ensures proper resource management and error handling: + +```mermaid +stateDiagram-v2 + [*] --> PluginInit: Plugin Creation + PluginInit --> Registered: Add to BifrostConfig + Registered --> PreHookCall: Request Received + + PreHookCall --> ModifyRequest: Normal Flow + PreHookCall --> ShortCircuitResponse: Return Response + PreHookCall --> ShortCircuitError: Return Error + + ModifyRequest --> ProviderCall: Send to Provider + ProviderCall --> PostHookCall: Receive Response + + ShortCircuitResponse --> PostHookCall: Skip Provider + ShortCircuitError --> PostHookCall: Pipeline Symmetry + + PostHookCall --> ModifyResponse: Process Result + PostHookCall --> RecoverError: Error Recovery + PostHookCall --> FallbackCheck: Check AllowFallbacks + PostHookCall --> ResponseReady: Pass Through + + FallbackCheck --> TryFallback: AllowFallbacks=true/nil + FallbackCheck --> ResponseReady: AllowFallbacks=false + TryFallback --> PreHookCall: Next Provider + + ModifyResponse --> ResponseReady: Modified + RecoverError --> ResponseReady: Recovered + ResponseReady --> [*]: Return to Client + + Registered --> CleanupCall: Bifrost Shutdown + CleanupCall --> [*]: Plugin Destroyed +``` + +### **Lifecycle Phase Details** + +**Discovery Phase:** + +- **Purpose:** Find and catalog available plugins +- **Sources:** Command line, environment variables, JSON configuration, directory scanning +- **Validation:** Basic existence and format checks +- **Output:** Plugin descriptors with metadata + +**Loading Phase:** + +- **Purpose:** Load plugin binaries into memory +- **Security:** Digital signature verification and checksum validation +- **Compatibility:** Interface implementation validation +- **Resource:** Memory and capability assessment + +**Initialization Phase:** + +- **Purpose:** Configure plugin with runtime settings +- **Timeout:** Bounded initialization time to prevent hanging +- **Dependencies:** External service connectivity verification +- **State:** Internal state setup and resource allocation + +**Runtime Phase:** + +- **Purpose:** Active request processing +- **Monitoring:** Continuous health checking and performance tracking +- **Recovery:** Automatic error recovery and degraded mode handling +- **Metrics:** Real-time performance and health metrics collection + +> **Plugin Lifecycle:** [Plugin Management β†’](../../enterprise/custom-plugins) + +--- + +## Plugin Execution Pipeline + +### **Request Processing Flow** + +The plugin pipeline ensures consistent, predictable execution while maintaining high performance: + +#### **Normal Execution Flow (No Short-Circuit)** + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Plugin2 + participant Provider + + Client->>Bifrost: Request + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: modified request + Bifrost->>Plugin2: PreHook(request) + Plugin2-->>Bifrost: modified request + Bifrost->>Provider: API Call + Provider-->>Bifrost: response + Bifrost->>Plugin2: PostHook(response) + Plugin2-->>Bifrost: modified response + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Final Response +``` + +**Execution Order:** + +1. **PreHooks:** Execute in registration order (1 β†’ 2 β†’ N) +2. **Provider Call:** If no short-circuit occurred +3. **PostHooks:** Execute in reverse order (N β†’ 2 β†’ 1) + +#### **Short-Circuit Response Flow (Cache Hit)** + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Cache + participant Auth + participant Provider + + Client->>Bifrost: Request + Bifrost->>Auth: PreHook(request) + Auth-->>Bifrost: modified request + Bifrost->>Cache: PreHook(request) + Cache-->>Bifrost: PluginShortCircuit{Response} + Note over Provider: Provider call skipped + Bifrost->>Cache: PostHook(response) + Cache-->>Bifrost: modified response + Bifrost->>Auth: PostHook(response) + Auth-->>Bifrost: modified response + Bifrost-->>Client: Cached Response +``` + +#### **Streaming Response Flow** + +For streaming responses, the plugin pipeline executes post-hooks for every delta/chunk received from the provider: + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Plugin2 + participant Provider + + Client->>Bifrost: Stream Request + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: modified request + Bifrost->>Plugin2: PreHook(request) + Plugin2-->>Bifrost: modified request + Bifrost->>Provider: Stream API Call + + loop For Each Delta + Provider-->>Bifrost: stream delta + Bifrost->>Plugin2: PostHook(delta) + Plugin2-->>Bifrost: modified delta + Bifrost->>Plugin1: PostHook(delta) + Plugin1-->>Bifrost: modified delta + Bifrost-->>Client: Send Delta + end + + Provider-->>Bifrost: final chunk (finish reason) + Bifrost->>Plugin2: PostHook(final) + Plugin2-->>Bifrost: modified final + Bifrost->>Plugin1: PostHook(final) + Plugin1-->>Bifrost: modified final + Bifrost-->>Client: Final Chunk +``` + +**Streaming Execution Characteristics:** + +1. **Delta Processing:** + - Each stream delta (chunk) goes through all post-hooks + - Plugins can modify/transform each delta before it reaches the client + - Deltas can contain: text content, tool calls, role changes, or usage info + +2. **Special Delta Types:** + - **Start Event:** Initial delta with role information + - **Content Delta:** Regular text or tool call content + - **Usage Update:** Token usage statistics (if enabled) + - **Final Chunk:** Contains finish reason and any final metadata + +3. **Plugin Considerations:** + - Plugins must handle streaming responses efficiently + - Each delta should be processed quickly to maintain stream responsiveness + - Plugins can track state across deltas using context + - Heavy processing should be done asynchronously + +4. **Error Handling:** + - If a post-hook returns an error, it's sent as an error stream chunk + - Stream is terminated after error chunks + - Plugins can recover from errors by providing valid responses + +5. **Performance Optimization:** + - Lightweight delta processing to minimize latency + - Object pooling for common data structures + - Non-blocking operations for logging and metrics + - Efficient memory management for stream processing + +> **Streaming Details:** [Streaming Guide β†’](../../quickstart/gateway/streaming) + +**Short-Circuit Rules:** + +- **Provider Skipped:** When plugin returns short-circuit response/error +- **PostHook Guarantee:** All executed PreHooks get corresponding PostHook calls +- **Reverse Order:** PostHooks execute in reverse order of PreHooks + +#### **Short-Circuit Error Flow (Allow Fallbacks)** + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Provider1 + participant Provider2 + + Client->>Bifrost: Request (Provider1 + Fallback Provider2) + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: PluginShortCircuit{Error, AllowFallbacks=true} + Note over Provider1: Provider1 call skipped + Bifrost->>Plugin1: PostHook(error) + Plugin1-->>Bifrost: error unchanged + + Note over Bifrost: Try fallback provider + Bifrost->>Plugin1: PreHook(request for Provider2) + Plugin1-->>Bifrost: modified request + Bifrost->>Provider2: API Call + Provider2-->>Bifrost: response + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Final Response +``` + +#### **Error Recovery Flow** + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Plugin2 + participant Provider + participant RecoveryPlugin + + Client->>Bifrost: Request + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: modified request + Bifrost->>Plugin2: PreHook(request) + Plugin2-->>Bifrost: modified request + Bifrost->>RecoveryPlugin: PreHook(request) + RecoveryPlugin-->>Bifrost: modified request + Bifrost->>Provider: API Call + Provider-->>Bifrost: error + Bifrost->>RecoveryPlugin: PostHook(error) + RecoveryPlugin-->>Bifrost: recovered response + Bifrost->>Plugin2: PostHook(response) + Plugin2-->>Bifrost: modified response + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Recovered Response +``` + +**Error Recovery Features:** + +- **Error Transformation:** Plugins can convert errors to successful responses +- **Graceful Degradation:** Provide fallback responses for service failures +- **Context Preservation:** Error context is maintained through recovery process + +### **Complex Plugin Decision Flow** + +Real-world plugin interactions involving authentication, rate limiting, and caching with different decision paths: + +```mermaid +graph TD + A["Client Request"] --> B["Bifrost"] + B --> C["Auth Plugin PreHook"] + C --> D{"Authenticated?"} + D -->|No| E["Return Auth Error
AllowFallbacks=false"] + D -->|Yes| F["RateLimit Plugin PreHook"] + F --> G{"Rate Limited?"} + G -->|Yes| H["Return Rate Error
AllowFallbacks=nil"] + G -->|No| I["Cache Plugin PreHook"] + I --> J{"Cache Hit?"} + J -->|Yes| K["Return Cached Response"] + J -->|No| L["Provider API Call"] + L --> M["Cache Plugin PostHook"] + M --> N["Store in Cache"] + N --> O["RateLimit Plugin PostHook"] + O --> P["Auth Plugin PostHook"] + P --> Q["Final Response"] + + E --> R["Skip Fallbacks"] + H --> S["Try Fallback Provider"] + K --> T["Skip Provider Call"] +``` + +### **Execution Characteristics** + +**Symmetric Execution Pattern:** + +- **Pre-processing:** Plugins execute in priority order (high to low) +- **Post-processing:** Plugins execute in reverse order (low to high) +- **Rationale:** Ensures proper cleanup and state management (last in, first out) + +**Performance Optimizations:** + +- **Timeout Boundaries:** Each plugin has configurable execution timeouts +- **Panic Recovery:** Plugin panics are caught and logged without crashing the system +- **Resource Limits:** Memory and CPU limits prevent runaway plugins +- **Circuit Breaking:** Repeated failures trigger plugin isolation + +**Error Handling Strategies:** + +- **Continue:** Use original request/response if plugin fails +- **Fail Fast:** Return error immediately if critical plugin fails +- **Retry:** Attempt plugin execution with exponential backoff +- **Fallback:** Use alternative plugin or default behavior + +> **Plugin Execution:** [Request Flow β†’](./request-flow#stage-3-plugin-pipeline-processing) + +--- + +## Plugin Discovery & Configuration + +### **Configuration Methods** + +**Current: Command-Line Plugin Loading** + +```bash +# Docker deployment +docker run -p 8080:8080 \ + -e APP_PLUGINS="maxim,custom-plugin" \ + maximhq/bifrost + +# NPM deployment +npx -y @maximhq/bifrost -plugins "maxim" +``` + +**Future: JSON Configuration System** + +```json +{ + "plugins": [ + { + "name": "maxim", + "source": "../../plugins/maxim", + "type": "local", + "config": { + "api_key": "env.MAXIM_API_KEY", + "log_repo_id": "env.MAXIM_LOG_REPO_ID" + } + } + ] +} +``` + +> **Plugin Configuration:** [Plugin Setup β†’](../../enterprise/custom-plugins) + +--- + +## Security & Validation + +### **Multi-Layer Security Model** + +Plugin security operates at multiple layers to ensure system integrity: + +```mermaid +graph TB + subgraph "Security Validation Layers" + L1[Layer 1: Binary Validation
Signature & Checksum] + L2[Layer 2: Interface Validation
Type Safety & Compatibility] + L3[Layer 3: Runtime Validation
Resource Limits & Timeouts] + L4[Layer 4: Execution Isolation
Panic Recovery & Error Handling] + end + + subgraph "Security Benefits" + Integrity[Code Integrity
Verified Authenticity] + Safety[Type Safety
Compile-time Checks] + Stability[System Stability
Isolated Failures] + Performance[Performance Protection
Resource Limits] + end + + L1 --> Integrity + L2 --> Safety + L3 --> Performance + L4 --> Stability +``` + +### **Validation Process** + +**Binary Security:** + +- **Digital Signatures:** Cryptographic verification of plugin authenticity +- **Checksum Validation:** File integrity verification +- **Source Verification:** Trusted source requirements + +**Interface Security:** + +- **Type Safety:** Interface implementation verification +- **Version Compatibility:** Plugin API version checking +- **Memory Safety:** Safe memory access patterns + +**Runtime Security:** + +- **Resource Quotas:** Memory and CPU usage limits +- **Execution Timeouts:** Bounded execution time +- **Sandbox Execution:** Isolated execution environment + +**Operational Security:** + +- **Health Monitoring:** Continuous plugin health assessment +- **Error Tracking:** Plugin error rate monitoring +- **Automatic Recovery:** Failed plugin restart and recovery + +--- + +## Plugin Performance & Monitoring + +### **Comprehensive Metrics System** + +Bifrost provides detailed metrics for plugin performance and health monitoring: + +```mermaid +graph TB + subgraph "Execution Metrics" + ExecTime[Execution Time
Latency per Plugin] + ExecCount[Execution Count
Request Volume] + SuccessRate[Success Rate
Error Percentage] + Throughput[Throughput
Requests/Second] + end + + subgraph "Resource Metrics" + MemoryUsage[Memory Usage
Per Plugin Instance] + CPUUsage[CPU Utilization
Processing Time] + IOMetrics[I/O Operations
Network/Disk Activity] + PoolUtilization[Pool Utilization
Resource Efficiency] + end + + subgraph "Health Metrics" + ErrorRate[Error Rate
Failed Executions] + PanicCount[Panic Recovery
Crash Events] + TimeoutCount[Timeout Events
Slow Executions] + RecoveryRate[Recovery Success
Failure Handling] + end + + subgraph "Business Metrics" + AddedLatency[Added Latency
Plugin Overhead] + SystemImpact[System Impact
Overall Performance] + FeatureUsage[Feature Usage
Plugin Utilization] + CostImpact[Cost Impact
Resource Consumption] + end +``` + +### **Performance Characteristics** + +**Plugin Execution Performance:** + +- **Typical Overhead:** 1-10ΞΌs per plugin for simple operations +- **Authentication Plugins:** 1-5ΞΌs for key validation +- **Rate Limiting Plugins:** 500ns for quota checks +- **Monitoring Plugins:** 200ns for metric collection +- **Transformation Plugins:** 2-10ΞΌs depending on complexity + +**Resource Usage Patterns:** + +- **Memory Efficiency:** Object pooling reduces allocations +- **CPU Optimization:** Minimal processing overhead +- **Network Impact:** Configurable external service calls +- **Storage Overhead:** Minimal for stateless plugins + +--- + +## Plugin Integration Patterns + +### **Common Integration Scenarios** + +**1. Authentication & Authorization** + +- **Pre-processing Hook:** Validate API keys or JWT tokens +- **Configuration:** External identity provider integration +- **Error Handling:** Return 401/403 responses for invalid credentials +- **Performance:** Sub-5ΞΌs validation with caching + +**2. Rate Limiting & Quotas** + +- **Pre-processing Hook:** Check request quotas and limits +- **Storage:** Redis or in-memory rate limit tracking +- **Algorithms:** Token bucket, sliding window, fixed window +- **Responses:** 429 Too Many Requests with retry headers + +**3. Request/Response Transformation** + +- **Dual Hooks:** Pre-processing for requests, post-processing for responses +- **Use Cases:** Data format conversion, field mapping, content filtering +- **Performance:** Streaming transformations for large payloads +- **Compatibility:** Provider-specific format adaptations + +**4. Monitoring & Analytics** + +- **Post-processing Hook:** Collect metrics and logs after request completion +- **Destinations:** Prometheus, DataDog, custom analytics systems +- **Data:** Request/response metadata, performance metrics, error tracking +- **Privacy:** Configurable data sanitization and filtering + +### **Plugin Communication Patterns** + +**Plugin-to-Plugin Communication:** + +- **Shared Context:** Plugins can store data in request context for downstream plugins +- **Event System:** Plugin can emit events for other plugins to consume +- **Data Passing:** Structured data exchange between related plugins + +**Plugin-to-External Service Communication:** + +- **HTTP Clients:** Built-in HTTP client pools for external API calls +- **Database Connections:** Connection pooling for database access +- **Message Queues:** Integration with message queue systems +- **Caching Systems:** Redis, Memcached integration for state storage + +> **πŸ“– Integration Examples:** [Plugin Development Guide β†’](../../enterprise/custom-plugins) + +--- + +## Related Architecture Documentation + +- **[Request Flow](./request-flow)** - Plugin execution in request processing pipeline +- **[Concurrency Model](./concurrency)** - Plugin concurrency and threading considerations +- **[Benchmarks](../../benchmarking/getting-started)** - Plugin performance characteristics and optimization +- **[MCP System](./mcp)** - Integration between plugins and MCP system + diff --git a/docs/architecture/core/providers.mdx b/docs/architecture/core/providers.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/core/request-flow.mdx b/docs/architecture/core/request-flow.mdx new file mode 100644 index 000000000..7cb43eb7e --- /dev/null +++ b/docs/architecture/core/request-flow.mdx @@ -0,0 +1,529 @@ +--- +title: "Request Flow" +description: "Deep dive into Bifrost's request processing pipeline - from transport layer ingestion through provider execution to response delivery." +icon: "route" +--- + +## Stage 1: Transport Layer Processing + +### **HTTP Transport Flow** + +```mermaid +sequenceDiagram + participant Client + participant HTTPTransport + participant Router + participant Validation + + Client->>HTTPTransport: POST /v1/chat/completions + HTTPTransport->>HTTPTransport: Parse Headers + HTTPTransport->>HTTPTransport: Extract Body + HTTPTransport->>Validation: Validate JSON Schema + Validation->>Router: BifrostRequest + Router-->>HTTPTransport: Processing Started + HTTPTransport-->>Client: HTTP 200 (async processing) +``` + +**Key Processing Steps:** + +1. **Request Reception** - FastHTTP server receives request +2. **Header Processing** - Extract authentication, content-type, custom headers +3. **Body Parsing** - JSON unmarshaling with schema validation +4. **Request Transformation** - Convert to internal `BifrostRequest` schema +5. **Context Creation** - Build request context with metadata + +**Performance Characteristics:** + +- **Parsing Time:** ~2.1ΞΌs for typical requests +- **Validation Overhead:** ~400ns for schema checks +- **Memory Allocation:** Zero-copy where possible + +### **Go SDK Flow** + +```mermaid +sequenceDiagram + participant Application + participant SDK + participant Core + participant Validation + + Application->>SDK: bifrost.ChatCompletion(req) + SDK->>SDK: Type Validation + SDK->>Core: Direct Function Call + Core->>Validation: Schema Validation + Validation-->>Core: Validated Request + Core-->>SDK: Processing Result + SDK-->>Application: Typed Response +``` + +**Advantages:** + +- **Zero Serialization** - Direct Go struct passing +- **Type Safety** - Compile-time validation +- **Lower Latency** - No HTTP/JSON overhead +- **Memory Efficiency** - No intermediate allocations + +--- + +## Stage 2: Request Routing & Load Balancing + +### **Provider Selection Logic** + +```mermaid +flowchart TD + Request[Incoming Request] --> ModelCheck{Model Available?} + ModelCheck -->|Yes| ProviderDirect[Use Specified Provider] + ModelCheck -->|No| ModelMapping[Model β†’ Provider Mapping] + + ProviderDirect --> KeyPool[API Key Pool] + ModelMapping --> KeyPool + + KeyPool --> WeightedSelect[Weighted Random Selection] + WeightedSelect --> HealthCheck{Provider Healthy?} + + HealthCheck -->|Yes| AssignWorker[Assign Worker] + HealthCheck -->|No| CircuitBreaker[Circuit Breaker] + + CircuitBreaker --> FallbackCheck{Fallback Available?} + FallbackCheck -->|Yes| FallbackProvider[Try Fallback] + FallbackCheck -->|No| ErrorResponse[Return Error] + + FallbackProvider --> KeyPool +``` + +**Key Selection Algorithm:** + +```go +// Weighted random key selection +type KeySelector struct { + keys []APIKey + weights []float64 + total float64 +} + +func (ks *KeySelector) SelectKey() *APIKey { + r := rand.Float64() * ks.total + cumulative := 0.0 + + for i, weight := range ks.weights { + cumulative += weight + if r <= cumulative { + return &ks.keys[i] + } + } + return &ks.keys[len(ks.keys)-1] +} +``` + +**Performance Metrics:** + +- **Key Selection Time:** ~10ns (constant time) +- **Health Check Overhead:** ~50ns (cached results) +- **Fallback Decision:** ~25ns (configuration lookup) + +--- + +## Stage 3: Plugin Pipeline Processing + +### **Pre-Processing Hooks** + +```mermaid +sequenceDiagram + participant Request + participant AuthPlugin + participant RateLimitPlugin + participant TransformPlugin + participant Core + + Request->>AuthPlugin: ProcessRequest() + AuthPlugin->>AuthPlugin: Validate API Key + AuthPlugin->>RateLimitPlugin: Authorized Request + + RateLimitPlugin->>RateLimitPlugin: Check Rate Limits + RateLimitPlugin->>TransformPlugin: Allowed Request + + TransformPlugin->>TransformPlugin: Modify Request + TransformPlugin->>Core: Final Request +``` + +**Plugin Execution Model:** + +```go +type PluginManager struct { + plugins []Plugin +} + +func (pm *PluginManager) ExecutePreHooks( + ctx BifrostContext, + req *BifrostRequest, +) (*BifrostRequest, *BifrostError) { + for _, plugin := range pm.plugins { + modifiedReq, err := plugin.ProcessRequest(ctx, req) + if err != nil { + return nil, err + } + req = modifiedReq + } + return req, nil +} +``` + +**Plugin Types & Performance:** + +| Plugin Type | Processing Time | Memory Impact | Failure Mode | +| --------------------- | --------------- | ------------- | ---------------------- | +| **Authentication** | ~1-5ΞΌs | Minimal | Reject request | +| **Rate Limiting** | ~500ns | Cache-based | Throttle/reject | +| **Request Transform** | ~2-10ΞΌs | Copy-on-write | Continue with original | +| **Monitoring** | ~200ns | Append-only | Continue silently | + +--- + +## Stage 4: MCP Tool Discovery & Integration + +### **Tool Discovery Process** + +```mermaid +flowchart TD + Request[Request with Model] --> MCPCheck{MCP Enabled?} + MCPCheck -->|No| SkipMCP[Skip MCP Processing] + MCPCheck -->|Yes| ClientLookup[MCP Client Lookup] + + ClientLookup --> ToolFilter[Tool Filtering] + ToolFilter --> ToolInject[Inject Tools into Request] + + ToolFilter --> IncludeCheck{Include Filter?} + ToolFilter --> ExcludeCheck{Exclude Filter?} + + IncludeCheck -->|Yes| IncludeTools[Include Specified Tools] + IncludeCheck -->|No| AllTools[Include All Tools] + + ExcludeCheck -->|Yes| RemoveTools[Remove Excluded Tools] + ExcludeCheck -->|No| KeepFiltered[Keep Filtered Tools] + + IncludeTools --> ToolInject + AllTools --> ToolInject + RemoveTools --> ToolInject + KeepFiltered --> ToolInject + + ToolInject --> EnhancedRequest[Request with Tools] + SkipMCP --> EnhancedRequest +``` + +**Tool Integration Algorithm:** + +```go +func (mcpm *MCPManager) EnhanceRequest( + ctx BifrostContext, + req *BifrostRequest, +) (*BifrostRequest, error) { + // Extract tool filtering from context + includeClients := ctx.GetStringSlice("mcp-include-clients") + excludeClients := ctx.GetStringSlice("mcp-exclude-clients") + includeTools := ctx.GetStringSlice("mcp-include-tools") + excludeTools := ctx.GetStringSlice("mcp-exclude-tools") + + // Get available tools + availableTools := mcpm.getAvailableTools(includeClients, excludeClients) + + // Filter tools + filteredTools := mcpm.filterTools(availableTools, includeTools, excludeTools) + + // Inject into request + if req.Params == nil { + req.Params = &ModelParameters{} + } + req.Params.Tools = append(req.Params.Tools, filteredTools...) + + return req, nil +} +``` + +**MCP Performance Impact:** + +- **Tool Discovery:** ~100-500ΞΌs (cached after first request) +- **Tool Filtering:** ~50-200ns per tool +- **Request Enhancement:** ~1-5ΞΌs depending on tool count + +--- + +## Stage 5: Memory Pool Management + +### **Object Pool Lifecycle** + +```mermaid +stateDiagram-v2 + [*] --> PoolInit: System Startup + PoolInit --> Available: Objects Pre-allocated + + Available --> Acquired: Request Processing + Acquired --> InUse: Object Populated + InUse --> Processing: Worker Processing + Processing --> Completed: Processing Done + Completed --> Reset: Object Cleanup + Reset --> Available: Return to Pool + + Available --> Expansion: Pool Exhaustion + Expansion --> Available: New Objects Created + + Reset --> GC: Pool Full + GC --> [*]: Garbage Collection +``` + +**Memory Pool Implementation:** + +```go +type MemoryPools struct { + channelPool sync.Pool + messagePool sync.Pool + responsePool sync.Pool + bufferPool sync.Pool +} + +func (mp *MemoryPools) GetChannel() *ProcessingChannel { + if ch := mp.channelPool.Get(); ch != nil { + return ch.(*ProcessingChannel) + } + return NewProcessingChannel() +} + +func (mp *MemoryPools) ReturnChannel(ch *ProcessingChannel) { + ch.Reset() // Clear previous data + mp.channelPool.Put(ch) +} +``` + +--- + +## Stage 6: Worker Pool Processing + +### **Worker Assignment & Execution** + +```mermaid +sequenceDiagram + participant Queue + participant WorkerPool + participant Worker + participant Provider + participant Circuit + + Queue->>WorkerPool: Enqueue Request + WorkerPool->>Worker: Assign Available Worker + Worker->>Circuit: Check Circuit Breaker + Circuit->>Provider: Forward Request + + Provider-->>Circuit: Response/Error + Circuit->>Circuit: Update Health Metrics + Circuit-->>Worker: Provider Response + Worker-->>WorkerPool: Release Worker + WorkerPool-->>Queue: Request Completed +``` + +**Worker Pool Architecture:** + +```go +type ProviderWorkerPool struct { + workers chan *Worker + queue chan *ProcessingJob + config WorkerPoolConfig + metrics *PoolMetrics +} + +func (pwp *ProviderWorkerPool) ProcessRequest(job *ProcessingJob) { + // Get worker from pool + worker := <-pwp.workers + + go func() { + defer func() { + // Return worker to pool + pwp.workers <- worker + }() + + // Process request + result := worker.Execute(job) + job.ResultChan <- result + }() +} +``` + +--- + +## Stage 7: Provider API Communication + +### **HTTP Request Execution** + +```mermaid +sequenceDiagram + participant Worker + participant HTTPClient + participant Provider + participant CircuitBreaker + participant Metrics + + Worker->>HTTPClient: PrepareRequest() + HTTPClient->>HTTPClient: Add Headers & Auth + HTTPClient->>CircuitBreaker: CheckHealth() + CircuitBreaker->>Provider: HTTP Request + + Provider-->>CircuitBreaker: HTTP Response + CircuitBreaker->>Metrics: Record Metrics + CircuitBreaker-->>HTTPClient: Response/Error + HTTPClient-->>Worker: Parsed Response +``` + +**Request Preparation Pipeline:** + +```go +func (w *ProviderWorker) ExecuteRequest(job *ProcessingJob) *ProviderResponse { + // Prepare HTTP request + httpReq := w.prepareHTTPRequest(job.Request) + + // Add authentication + w.addAuthentication(httpReq, job.APIKey) + + // Execute with timeout + ctx, cancel := context.WithTimeout(context.Background(), job.Timeout) + defer cancel() + + httpResp, err := w.httpClient.Do(httpReq.WithContext(ctx)) + if err != nil { + return w.handleError(err, job) + } + + // Parse response + return w.parseResponse(httpResp, job) +} +``` + +--- + +## Stage 8: Tool Execution & Response Processing + +### **MCP Tool Execution Flow** + +```mermaid +sequenceDiagram + participant Provider + participant MCPProcessor + participant MCPServer + participant ToolExecutor + participant ResponseBuilder + + Provider->>MCPProcessor: Response with Tool Calls + MCPProcessor->>MCPProcessor: Extract Tool Calls + + loop For each tool call + MCPProcessor->>MCPServer: Execute Tool + MCPServer->>ToolExecutor: Tool Invocation + ToolExecutor-->>MCPServer: Tool Result + MCPServer-->>MCPProcessor: Tool Response + end + + MCPProcessor->>ResponseBuilder: Combine Results + ResponseBuilder-->>Provider: Enhanced Response +``` + +**Tool Execution Pipeline:** + +```go +func (mcp *MCPProcessor) ProcessToolCalls( + response *ProviderResponse, +) (*ProviderResponse, error) { + toolCalls := mcp.extractToolCalls(response) + if len(toolCalls) == 0 { + return response, nil + } + + // Execute tools concurrently + results := make(chan ToolResult, len(toolCalls)) + for _, toolCall := range toolCalls { + go func(tc ToolCall) { + result := mcp.executeTool(tc) + results <- result + }(toolCall) + } + + // Collect results + toolResults := make([]ToolResult, 0, len(toolCalls)) + for i := 0; i < len(toolCalls); i++ { + toolResults = append(toolResults, <-results) + } + + // Enhance response + return mcp.enhanceResponse(response, toolResults), nil +} +``` + +--- + +## Stage 9: Post-Processing & Response Formation + +### **Plugin Post-Processing** + +```mermaid +sequenceDiagram + participant CoreResponse + participant LoggingPlugin + participant CachePlugin + participant MetricsPlugin + participant Transport + + CoreResponse->>LoggingPlugin: ProcessResponse() + LoggingPlugin->>LoggingPlugin: Log Request/Response + LoggingPlugin->>CachePlugin: Response + Logs + + CachePlugin->>CachePlugin: Cache Response + CachePlugin->>MetricsPlugin: Cached Response + + MetricsPlugin->>MetricsPlugin: Record Metrics + MetricsPlugin->>Transport: Final Response +``` + +**Response Enhancement Pipeline:** + +```go +func (pm *PluginManager) ExecutePostHooks( + ctx BifrostContext, + req *BifrostRequest, + resp *BifrostResponse, +) (*BifrostResponse, error) { + for _, plugin := range pm.plugins { + enhancedResp, err := plugin.ProcessResponse(ctx, req, resp) + if err != nil { + // Log error but continue processing + pm.logger.Warn("Plugin post-processing error", "plugin", plugin.Name(), "error", err) + continue + } + resp = enhancedResp + } + return resp, nil +} +``` + +### **Response Serialization** + +```mermaid +flowchart TD + Response[BifrostResponse] --> Format{Response Format} + Format -->|HTTP| JSONSerialize[JSON Serialization] + Format -->|SDK| DirectReturn[Direct Go Struct] + + JSONSerialize --> Compress[Compression] + DirectReturn --> TypeCheck[Type Validation] + + Compress --> Headers[Set Headers] + TypeCheck --> Return[Return Response] + + Headers --> HTTPResponse[HTTP Response] + HTTPResponse --> Client[Client Response] + Return --> Client +``` + +--- + +## Related Architecture Documentation + +- **[Concurrency Model](./concurrency)** - Worker pools and threading details +- **[Plugin System](./plugins)** - Plugin execution and lifecycle +- **[MCP System](./mcp)** - Tool discovery and execution internals +- **[Benchmarks](../../benchmarking/getting-started)** - Detailed performance analysis diff --git a/docs/architecture/framework/config-store.mdx b/docs/architecture/framework/config-store.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/framework/log-store.mdx b/docs/architecture/framework/log-store.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/framework/pricing.mdx b/docs/architecture/framework/pricing.mdx new file mode 100644 index 000000000..413466750 --- /dev/null +++ b/docs/architecture/framework/pricing.mdx @@ -0,0 +1,206 @@ +--- +title: "Pricing Module" +description: "Dynamic model pricing and cost calculation system for AI model usage tracking and billing." +icon: "dollar-sign" +--- + +The Pricing Module provides intelligent cost calculation and dynamic pricing management for AI model usage across all providers supported by Bifrost. It offers real-time cost tracking, multi-modal pricing support, and automatic pricing data synchronization. + +## Core Features + +### **Automatic Pricing Synchronization** +The pricing system manages pricing data through a two-phase approach: + +**Startup Behavior:** +- **With ConfigStore**: Downloads pricing sheet from Maxim's datasheet and persists it to the config store, then loads into memory for fast lookups +- **Without ConfigStore**: Downloads pricing sheet directly into memory on every startup + +**Ongoing Synchronization:** +- When ConfigStore is available, automatic sync occurs every 24 hours to keep pricing data current +- All pricing data is cached in memory for O(1) lookup performance during cost calculations + +This ensures cost calculations always use the latest pricing information from AI providers while maintaining optimal performance. + +### **Multi-Modal Cost Calculation** +Supports diverse pricing models across different AI operation types: +- **Text Operations**: Token-based and character-based pricing for chat completions, text completions, and embeddings +- **Audio Processing**: Token-based pricing for speech synthesis and transcription +- **Image Processing**: Per-image costs with tiered pricing for high-token contexts + +### **Intelligent Cache Cost Handling** +Integrates with semantic caching to provide accurate cost calculations: +- **Cache Hits**: Zero cost for direct cache hits, embedding cost only for semantic matches +- **Cache Misses**: Combined cost of base model usage plus embedding generation for cache storage + +### **Tiered Pricing Support** +Automatically applies different pricing rates for high-token contexts (above 128k tokens), reflecting real provider pricing models. + +## Architecture + +### PricingManager +The central component that handles all pricing operations: + +```go +type PricingManager struct { + configStore configstore.ConfigStore + logger schemas.Logger + + // In-memory cache for fast access + pricingData map[string]configstore.TableModelPricing + mu sync.RWMutex + + // Background sync worker + syncTicker *time.Ticker + done chan struct{} + wg sync.WaitGroup +} +``` + +### Pricing Data Structure +Each model's pricing information includes comprehensive cost metrics: + +```go +type PricingEntry struct { + // Basic pricing + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + Provider string `json:"provider"` + Mode string `json:"mode"` + + // Media pricing + InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` + InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` + InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` + + // Character-based pricing + InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` + OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"` + + // Tiered pricing (above 128k tokens) + InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` + OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` + + // Special operation pricing + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` + InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` + OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` +} +``` + +## Usage in Plugins + +### Initialization +In Bifrost's gateway, the `PricingManager` is initialized once at the start and shared across all plugins: + +```go +import "github.com/maximhq/bifrost/framework/pricing" + +// Initialize pricing manager with config store and logger +pricingManager, err := pricing.Init(configStore, logger) +if err != nil { + return fmt.Errorf("failed to initialize pricing manager: %w", err) +} +``` + +### Basic Cost Calculation +Calculate costs from a Bifrost response: + +```go +// Calculate cost for a completed request +cost := pricingManager.CalculateCost( + result, // *schemas.BifrostResponse + schemas.OpenAI, // provider + "gpt-4", // model + schemas.ChatCompletionRequest, // request type +) + +logger.Info("Request cost: $%.6f", cost) +``` + +### Advanced Cost Calculation with Usage Details +For more granular cost calculation with custom usage data: + +```go +// Custom usage calculation +usage := &schemas.LLMUsage{ + PromptTokens: 1500, + CompletionTokens: 800, + TotalTokens: 2300, +} + +cost := pricingManager.CalculateCostFromUsage( + "openai", // provider + "gpt-4", // model + usage, // usage data + schemas.ChatCompletionRequest, // request type + false, // is cache read + false, // is batch + nil, // audio seconds (for audio models) + nil, // audio token details +) +``` + +### Cache Aware Cost Calculation +For workflows that implement semantic caching, use cache-aware cost calculation: + +```go +// This automatically handles cache hits/misses and embedding costs +cost := pricingManager.CalculateCostWithCacheDebug( + result, // *schemas.BifrostResponse with cache debug info + schemas.Anthropic, // provider + "claude-3-sonnet", // model + schemas.ChatCompletionRequest, // request type +) + +// Cache hits return 0 for direct hits, embedding cost for semantic matches +// Cache misses return base model cost + embedding generation cost +``` + +## Error Handling and Fallbacks + +The pricing module handles missing pricing data gracefully: + +```go +// Pricing lookup with fallback behavior +func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) { + // Try direct lookup + pricing, ok := pm.pricingData[makeKey(model, provider, normalizeRequestType(requestType))] + if !ok { + pm.logger.Warn("pricing not found for model %s and provider %s of request type %s", + model, provider, requestType) + return nil, false + } + return &pricing, true +} + +// When pricing is not found, CalculateCost returns 0.0 and logs a warning +// This ensures operations continue smoothly without billing failures +``` + + +## Cleanup and Lifecycle Management + +Properly cleanup resources when shutting down: + +```go +// Cleanup pricing manager resources +defer func() { + if err := pricingManager.Cleanup(); err != nil { + logger.Error("Failed to cleanup pricing manager: %v", err) + } +}() +``` + +## Thread Safety + +All PricingManager operations are thread-safe, making it suitable for concurrent usage across multiple plugins and goroutines. The internal pricing data cache uses read-write mutexes for optimal performance during frequent lookups. + +## Best Practices + +1. **Shared Instance**: Use a single PricingManager instance across all plugins to avoid redundant data synchronization +2. **Error Handling**: Always handle the case where pricing returns 0.0 due to missing model data +3. **Logging**: Monitor pricing sync failures and missing model warnings in production +4. **Cache Awareness**: Use `CalculateCostWithCacheDebug` when implementing caching features +5. **Resource Cleanup**: Always call `Cleanup()` during application shutdown to prevent resource leaks + +The Pricing Module provides a robust, production-ready foundation for implementing billing, budgeting, and cost monitoring features in Bifrost plugins. diff --git a/docs/architecture/framework/vector-store.mdx b/docs/architecture/framework/vector-store.mdx new file mode 100644 index 000000000..1fee714a9 --- /dev/null +++ b/docs/architecture/framework/vector-store.mdx @@ -0,0 +1,506 @@ +--- +title: "Vector Store" +description: "Vector database implementations for semantic search, embeddings storage, and AI-powered features in Bifrost." +icon: "diagram-project" +--- + +## Overview + +The VectorStore is a core component of Bifrost's framework package that provides a unified interface for vector database operations. It enables plugins to store embeddings, perform similarity searches, and build AI-powered features like semantic caching, content recommendations, and knowledge retrieval. + +**Key Capabilities:** +- **Vector Similarity Search**: Find semantically similar content using embeddings +- **Namespace Management**: Organize data into separate collections with custom schemas +- **Flexible Filtering**: Query data with complex filters and pagination +- **Multiple Backends**: Support for Weaviate and Redis vector stores +- **High Performance**: Optimized for production workloads +- **Scalable Storage**: Handle millions of vectors with efficient indexing + +## Supported Vector Stores + +Bifrost currently supports two vector store implementations: + +- **[Weaviate](#weaviate)**: Production-ready vector database with gRPC support and advanced querying +- **[Redis](#redis)**: High-performance in-memory vector store using RediSearch + +## VectorStore Interface Usage + +### Creating Namespaces +Create collections (namespaces) with custom schemas: + +```go +// Define properties for your data +properties := map[string]vectorstore.VectorStoreProperties{ + "content": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The main content text", + }, + "category": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "Content category", + }, + "tags": { + DataType: vectorstore.VectorStorePropertyTypeStringArray, + Description: "Content tags", + }, +} + +// Create namespace +err := store.CreateNamespace(ctx, "my_content", 1536, properties) +if err != nil { + log.Fatal("Failed to create namespace:", err) +} +``` + +### Storing Data with Embeddings +Add data with vector embeddings for similarity search: + +```go +// Your embedding data (typically from an embedding model) +embedding := []float32{0.1, 0.2, 0.3 } // example 3-dimensional vector + +// Metadata associated with this vector +metadata := map[string]interface{}{ + "content": "This is my content text", + "category": "documentation", + "tags": []string{"guide", "tutorial"}, +} + +// Store in vector database +err := store.Add(ctx, "my_content", "unique-id-123", embedding, metadata) +if err != nil { + log.Fatal("Failed to add data:", err) +} +``` + +### Similarity Search +Find similar content using vector similarity: + +```go +// Query embedding (from user query) +queryEmbedding := []float32{0.15, 0.25, 0.35, ...} + +// Optional filters +filters := []vectorstore.Query{ + { + Field: "category", + Operator: vectorstore.QueryOperatorEqual, + Value: "documentation", + }, +} + +// Perform similarity search +results, err := store.GetNearest( + ctx, + "my_content", // namespace + queryEmbedding, // query vector + filters, // optional filters + []string{"content", "category"}, // fields to return + 0.7, // similarity threshold (0-1) + 10, // limit +) + +for _, result := range results { + fmt.Printf("Score: %.3f, Content: %s\n", *result.Score, result.Properties["content"]) +} +``` + +### Data Retrieval and Management +Query and manage stored data: + +```go +// Get specific item by ID +item, err := store.GetChunk(ctx, "my_content", "unique-id-123") +if err != nil { + log.Fatal("Failed to get item:", err) +} + +// Get all items with filtering and pagination +allResults, cursor, err := store.GetAll( + ctx, + "my_content", + []vectorstore.Query{ + {Field: "category", Operator: vectorstore.QueryOperatorEqual, Value: "documentation"}, + }, + []string{"content", "tags"}, // select fields + nil, // cursor for pagination + 50, // limit +) + +// Delete items +err = store.Delete(ctx, "my_content", "unique-id-123") +``` + +## Weaviate + +Weaviate is a production-ready vector database solution that provides advanced querying capabilities, gRPC support for high performance, and flexible schema management for production deployments. + +### Key Features + +- **gRPC Support**: Enhanced performance with gRPC connections +- **Advanced Filtering**: Complex query operations with multiple conditions +- **Schema Management**: Flexible schema definition for different data types +- **Cloud & Self-Hosted**: Support for both Weaviate Cloud and self-hosted deployments +- **Scalable Storage**: Handle millions of vectors with efficient indexing + +### Setup & Installation + +**Weaviate Cloud:** +- Sign up at [cloud.weaviate.io](https://cloud.weaviate.io) +- Create a new cluster +- Get your API key and cluster URL + +**Local Weaviate:** +```bash +# Using Docker +docker run -d \ + --name weaviate \ + -p 8080:8080 \ + -e QUERY_DEFAULTS_LIMIT=25 \ + -e AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' \ + -e PERSISTENCE_DATA_PATH='/var/lib/weaviate' \ + semitechnologies/weaviate:latest +``` + +### Configuration Options + + + + + +```go +// Configure Weaviate vector store +vectorConfig := &vectorstore.Config{ + Enabled: true, + Type: vectorstore.VectorStoreTypeWeaviate, + Config: vectorstore.WeaviateConfig{ + Scheme: "http", // "http" for local, "https" for cloud + Host: "localhost:8080", // Your Weaviate host + ApiKey: "your-weaviate-api-key", // Required for Weaviate Cloud; optional for local/self-hosted + + // Enable gRPC for improved performance (optional) + GrpcConfig: &vectorstore.WeaviateGrpcConfig{ + Host: "localhost:50051", // gRPC port + Secured: false, // true for TLS + }, + }, +} + +// Create vector store +store, err := vectorstore.NewVectorStore(context.Background(), vectorConfig, logger) +if err != nil { + log.Fatal("Failed to create vector store:", err) +} +``` + + + + + +**Local Setup:** +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "localhost:8080" + } + } +} +``` + +**Cloud Setup with gRPC:** +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "https", + "host": "your-weaviate-host", + "api_key": "your-weaviate-api-key", + "grpc_config": { + "host": "your-weaviate-grpc-host", + "secured": true + } + } + } +} +``` + + + + + + +gRPC host should include the port. If no port is specified, port 80 is used for insecured connections and port 443 for secured connections. + + +### Advanced Features + +**gRPC Performance Optimization:** +Enable gRPC for better performance in production: + +```go +vectorConfig := &vectorstore.Config{ + Type: vectorstore.VectorStoreTypeWeaviate, + Config: vectorstore.WeaviateConfig{ + Scheme: "https", + Host: "your-weaviate-host", + ApiKey: "your-api-key", + + // Enable gRPC for better performance + GrpcConfig: &vectorstore.WeaviateGrpcConfig{ + Host: "your-weaviate-grpc-host:443", + Secured: true, + }, + }, +} +``` + +### Production Considerations + + +**Performance**: For production environments, consider using gRPC configuration for better performance and enable appropriate authentication mechanisms for your Weaviate deployment. + + + +**Authentication**: Always use API keys for Weaviate Cloud deployments and configure proper authentication for self-hosted instances in production. + + +--- + +## Redis + +Redis provides high-performance in-memory vector storage using RediSearch, ideal for applications requiring sub-millisecond response times and fast semantic search capabilities. + +### Key Features + +- **High Performance**: Sub-millisecond cache retrieval with Redis's in-memory storage +- **Cost Effective**: Open-source solution with no licensing costs +- **HNSW Algorithm**: Fast vector similarity search with excellent recall rates +- **Connection Pooling**: Advanced connection management for high-throughput applications +- **TTL Support**: Automatic expiration of cached entries +- **Streaming Support**: Full streaming response caching with proper chunk ordering +- **Flexible Filtering**: Advanced metadata filtering with exact string matching + +### Setup & Installation + +**Redis Cloud:** +- Sign up at [cloud.redis.io](https://cloud.redis.io) +- Create a new database with RediSearch module enabled +- Get your connection details + +**Local Redis with RediSearch:** +```bash +# Using Docker with Redis Stack (includes RediSearch) +docker run -d --name redis-stack -p 6379:6379 redis/redis-stack:latest +``` + +### Configuration Options + + + + + +```go +// Configure Redis vector store +vectorConfig := &vectorstore.Config{ + Enabled: true, + Type: vectorstore.VectorStoreTypeRedis, + Config: vectorstore.RedisConfig{ + Addr: "localhost:6379", // Redis server address - REQUIRED + Username: "", // Optional: Redis username + Password: "", // Optional: Redis password + DB: 0, // Optional: Redis database number (default: 0) + + // Optional: Connection pool settings + PoolSize: 10, // Maximum socket connections + MaxActiveConns: 10, // Maximum active connections + MinIdleConns: 5, // Minimum idle connections + MaxIdleConns: 10, // Maximum idle connections + + // Optional: Timeout settings + DialTimeout: 5 * time.Second, // Connection timeout + ReadTimeout: 3 * time.Second, // Read timeout + WriteTimeout: 3 * time.Second, // Write timeout + ContextTimeout: 10 * time.Second, // Operation timeout + }, +} + +// Create vector store +store, err := vectorstore.NewVectorStore(context.Background(), vectorConfig, logger) +if err != nil { + log.Fatal("Failed to create vector store:", err) +} +``` + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "localhost:6379", + "username": "", + "password": "", + "db": 0, + "pool_size": 10, + "max_active_conns": 10, + "min_idle_conns": 5, + "max_idle_conns": 10, + "dial_timeout": "5s", + "read_timeout": "3s", + "write_timeout": "3s", + "context_timeout": "10s" + } + } +} +``` + +**For Redis Cloud:** +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "your-redis-host:port", + "username": "your-username", + "password": "your-password", + "db": 0, + "context_timeout": "10s" + } + } +} +``` + + + + + +### Redis-Specific Features + +**Vector Search Algorithm:** +Redis uses the **HNSW (Hierarchical Navigable Small World)** algorithm for vector similarity search, which provides: + +- **Fast Search**: O(log N) search complexity +- **High Accuracy**: Excellent recall rates for similarity search +- **Memory Efficient**: Optimized for in-memory operations +- **Cosine Similarity**: Uses cosine distance metric for semantic similarity + +**Connection Pool Management:** +Redis provides extensive connection pool configuration: + +```go +config := vectorstore.RedisConfig{ + Addr: "localhost:6379", + PoolSize: 20, // Max socket connections + MaxActiveConns: 20, // Max active connections + MinIdleConns: 5, // Min idle connections + MaxIdleConns: 10, // Max idle connections + ConnMaxLifetime: 30 * time.Minute, // Connection lifetime + ConnMaxIdleTime: 5 * time.Minute, // Idle connection timeout + DialTimeout: 5 * time.Second, // Connection timeout + ReadTimeout: 3 * time.Second, // Read timeout + WriteTimeout: 3 * time.Second, // Write timeout + ContextTimeout: 10 * time.Second, // Operation timeout +} +``` + +### Performance Optimization + +**Connection Pool Tuning:** +For high-throughput applications, tune the connection pool settings: + +```json +{ + "vector_store": { + "config": { + "pool_size": 50, // Increase for high concurrency + "max_active_conns": 50, // Match pool_size + "min_idle_conns": 10, // Keep connections warm + "max_idle_conns": 20, // Allow some idle connections + "conn_max_lifetime": "1h", // Refresh connections periodically + "conn_max_idle_time": "10m" // Close idle connections + } + } +} +``` + +**Memory Optimization:** +- **TTL**: Use appropriate TTL values to prevent memory bloat +- **Namespace Cleanup**: Regularly clean up unused namespaces + +**Batch Operations:** +Redis supports efficient batch operations: + +```go +// Batch retrieval +results, err := store.GetChunks(ctx, namespace, []string{"id1", "id2", "id3"}) + +// Batch deletion +deleteResults, err := store.DeleteAll(ctx, namespace, queries) +``` + +### Production Considerations + + +**RediSearch Module Required**: Redis integration requires the RediSearch module to be enabled on your Redis instance. This module provides the vector search capabilities needed for semantic caching. + + + +**Production Considerations**: +- Use Redis AUTH for production deployments +- Configure appropriate connection timeouts +- Monitor memory usage and set appropriate TTL values + + +--- + +## Use Cases + +### [Semantic Caching](../../../features/semantic-caching) +Build intelligent caching systems that understand query intent rather than just exact matches. + +**Applications:** +- Customer support systems with FAQ matching +- Code completion and documentation search +- Content management with semantic deduplication + +### Knowledge Base & Search +Create intelligent search systems that understand user queries contextually. + +**Applications:** +- Document search and retrieval systems +- Product recommendation engines +- Research paper and knowledge discovery platforms + +### Content Classification +Automatically categorize and tag content based on semantic similarity. + +**Applications:** +- Email classification and routing +- Content moderation and filtering +- News article categorization and clustering + +### Recommendation Systems +Build personalized recommendation engines using vector similarity. + +**Applications:** +- Product recommendations based on user preferences +- Content suggestions for media platforms +- Similar document or article recommendations + +## Related Documentation + +| Topic | Documentation | Description | +|-------|---------------|-------------| +| **Framework Overview** | [What is Framework](../what-is-framework) | Understanding the framework package and VectorStore interface | +| **Semantic Caching** | [Semantic Caching](../../../features/semantic-caching) | Using VectorStore for AI response caching | diff --git a/docs/architecture/framework/what-is-framework.mdx b/docs/architecture/framework/what-is-framework.mdx new file mode 100644 index 000000000..2ba04a207 --- /dev/null +++ b/docs/architecture/framework/what-is-framework.mdx @@ -0,0 +1,49 @@ +--- +title: "What is framework?" +description: "Framework is Bifrost's shared storage and utilities SDK package that provides common database interfaces and logic for the plugin ecosystem." +icon: "play" +--- + +Framework serves as the foundation layer that enables plugins to implement consistent data management patterns without reinventing storage solutions. + +## Installation + +```bash +go get github.com/maximhq/bifrost/framework +``` + +## Purpose + +The framework package was designed to solve a fundamental challenge in plugin development: providing standardized, reliable storage and utility interfaces that plugins can depend on. Instead of each plugin implementing its own database logic, configuration management, or logging systems, framework offers battle-tested, shared implementations. + +## Core Components + +### ConfigStore +A unified configuration persistence layer that provides consistent storage patterns for plugin settings, provider configurations, and system state. Plugins can leverage `ConfigStore` to manage their configuration data with built-in CRUD operations, transaction support, and schema management. + +### LogStore +Standardized logging and audit trail capabilities that enable plugins to implement observability features. `LogStore` provides structured logging, search and filtering capabilities, pagination support, and automated data retention policies. + +### VectorStore +Vector database operations designed for AI-powered plugins that need semantic capabilities. `VectorStore` handles embeddings management, similarity search operations, and namespace isolation, making it easy for plugins to add features like semantic caching, content search, and AI-powered recommendations. + +### Pricing Module +Cost calculation and model pricing management tools that help plugins implement billing and usage tracking features. The pricing system supports multi-tier pricing models, real-time usage tracking, and dynamic pricing updates. + +## Benefits for Plugin Developers + +**Shared Logic**: Common patterns for configuration, logging, and data management are provided out-of-the-box, reducing development time and ensuring consistency across plugins. + +**Standardized Interfaces**: All framework components use consistent APIs, making it easier for developers to work across different plugins and maintain code quality. + +**Pluggable Architecture**: The interface-based design allows different storage backends to be used without changing plugin code, providing flexibility for different deployment scenarios. + +**Transaction Support**: Built-in transaction management and error handling ensure data integrity and provide reliable rollback capabilities. + +**Production Ready**: Framework components are battle-tested in production environments and include features like connection pooling, retry logic, and performance optimizations. + +## Integration with Bifrost + +Framework seamlessly integrates with the Bifrost ecosystem, providing the storage foundation that powers core features like provider management, request logging, semantic caching, and governance. When plugins use framework components, they automatically participate in Bifrost's unified data management strategy. + +The framework package enables plugin developers to focus on their core business logic while relying on robust, shared infrastructure for all storage and utility needs. \ No newline at end of file diff --git a/docs/architecture/plugins/governance.mdx b/docs/architecture/plugins/governance.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/jsonparser.mdx b/docs/architecture/plugins/jsonparser.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/logging.mdx b/docs/architecture/plugins/logging.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/maxim.mdx b/docs/architecture/plugins/maxim.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/mocker.mdx b/docs/architecture/plugins/mocker.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/semantic-cache.mdx b/docs/architecture/plugins/semantic-cache.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/plugins/telemetry.mdx b/docs/architecture/plugins/telemetry.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/architecture/transports/in-memory-store.mdx b/docs/architecture/transports/in-memory-store.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/benchmarking/getting-started.mdx b/docs/benchmarking/getting-started.mdx new file mode 100644 index 000000000..f1289b354 --- /dev/null +++ b/docs/benchmarking/getting-started.mdx @@ -0,0 +1,81 @@ +--- +title: "Getting Started" +description: "Introduction to Bifrost's performance capabilities and how to choose the right instance size for your workload." +icon: "rocket" +--- + +## Overview + +Bifrost has been rigorously tested under high load conditions to ensure optimal performance for production deployments. Our benchmark tests demonstrate exceptional performance characteristics at **5,000 requests per second (RPS)** across different AWS EC2 instance types. + +**Key Performance Highlights:** +- **Perfect Success Rate**: 100% request success rate under high load +- **Minimal Overhead**: Less than 15Β΅s added latency per request on average +- **Efficient Queue Management**: Sub-microsecond queue wait times on optimized instances +- **Fast Key Selection**: Near-instantaneous weighted API key selection (~10 ns) + +--- + +## Test Environment Summary + +Bifrost was benchmarked on two primary AWS EC2 instance configurations: + +### **t3.medium (2 vCPUs, 4GB RAM)** +- **Buffer Size**: 15,000 +- **Initial Pool Size**: 10,000 +- **Use Case**: Cost-effective option for moderate workloads + +### **t3.xlarge (4 vCPUs, 16GB RAM)** +- **Buffer Size**: 20,000 +- **Initial Pool Size**: 15,000 +- **Use Case**: High-performance option for demanding workloads + +--- + +## Performance Comparison at a Glance + +| Metric | t3.medium | t3.xlarge | Improvement | +|--------|-----------|-----------|-------------| +| **Success Rate @ 5k RPS** | 100% | 100% | No failed requests | +| **Bifrost Overhead** | 59 Β΅s | 11 Β΅s | **-81%** | +| **Average Latency** | 2.12s | 1.61s | **-24%** | +| **Queue Wait Time** | 47.13 Β΅s | 1.67 Β΅s | **-96%** | +| **JSON Marshaling** | 63.47 Β΅s | 26.80 Β΅s | **-58%** | +| **Response Parsing** | 11.30 ms | 2.11 ms | **-81%** | +| **Peak Memory Usage** | 1,312.79 MB | 3,340.44 MB | +155% | + +> **Note**: t3.xlarge tests used significantly larger response payloads (~10 KB vs ~1 KB), yet still achieved better performance metrics. + + +All benchmarks are on mocked OpenAI calls, whose latency and payload size are mentioned in the respective analysis pages. + + +--- + +## Configuration Flexibility + +One of Bifrost's key strengths is its **configuration flexibility**. You can fine-tune the speed ↔ memory trade-off based on your specific requirements: + +| Configuration Parameter | Effect | +|------------------------|--------| +| `initial_pool_size` | Higher values = faster performance, more memory usage | +| `buffer_size` & `concurrency` | Controls queue depth and max parallel workers (per provider) | +| `retry` & `timeout` | Tune aggressiveness for each provider to meet your SLOs | + +**Configuration Philosophy:** +- **Higher settings** (like t3.xlarge profile) prioritize raw speed +- **Lower settings** (like t3.medium profile) optimize for memory efficiency +- **Custom tuning** lets you find the sweet spot for your specific workload + +--- + +## Next Steps + +### **Detailed Performance Analysis** +- **[t3.medium Performance](./t3.medium)** - Deep dive into cost-effective performance +- **[t3.xlarge Performance](./t3.xl)** - High-performance configuration analysis + +### **Run Your Own Tests** +- **[Run Your Own Benchmarks](./run-your-own-benchmarks)** - Step-by-step guide to benchmark Bifrost in your environment + +Ready to dive deeper? Choose your instance type above or learn how to run your own performance tests. diff --git a/docs/benchmarking/run-your-own-benchmarks.mdx b/docs/benchmarking/run-your-own-benchmarks.mdx new file mode 100644 index 000000000..2e75d53d1 --- /dev/null +++ b/docs/benchmarking/run-your-own-benchmarks.mdx @@ -0,0 +1,355 @@ +--- +title: "Run Your Own Benchmarks" +description: "Step-by-step guide to benchmark Bifrost in your own environment using the official benchmarking tool." +icon: "stopwatch" +--- + +## Overview + +Want to see Bifrost's performance in your specific environment? The [**Bifrost Benchmarking Repository**](https://github.com/maximhq/bifrost-benchmarking) provides everything you need to conduct comprehensive performance tests tailored to your infrastructure and workload requirements. + +**What You Can Test:** +- **Custom Instance Sizes** - Test on your preferred AWS/GCP/Azure instances +- **Your Workload Patterns** - Use your actual request/response sizes +- **Different Configurations** - Compare various Bifrost settings +- **Provider Comparisons** - Benchmark against other AI gateways +- **Load Scenarios** - Test burst loads, sustained traffic, and endurance + +> **πŸ’‘ Open Source**: The benchmarking tool is completely open source! Feel free to submit pull requests if you think anything is missing or could be improved. + +--- + +## Prerequisites + +Before running benchmarks, ensure you have: + +- **Go 1.23+** installed on your testing machine +- **Bifrost instance** running and accessible +- **Target API providers** configured (OpenAI, Anthropic, etc.) +- **Network access** between benchmark tool and Bifrost +- **Sufficient resources** on the testing machine to generate load + +--- + +## Quick Start + +### **1. Clone the Repository** + +```bash +git clone https://github.com/maximhq/bifrost-benchmarking.git +cd bifrost-benchmarking +``` + +### **2. Build the Benchmark Tool** + +```bash +go build benchmark.go +``` + +This creates a `benchmark` executable (or `benchmark.exe` on Windows). + +### **3. Run Your First Benchmark** + +```bash +# Basic benchmark: 500 RPS for 10 seconds +./benchmark -provider bifrost -port 8080 + +# Custom benchmark: 1000 RPS for 30 seconds +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 30 -output my_results.json +``` + +--- + +## Configuration Options + +The benchmark tool offers extensive configuration through command-line flags: + +### **Basic Configuration** + +| Flag | Required | Description | Default | +|------|----------|-------------|---------| +| `-provider ` | βœ… | Provider name (e.g., `bifrost`, `litellm`) | None | +| `-port ` | βœ… | Port number of your Bifrost instance | None | +| `-endpoint ` | ❌ | API endpoint path | `v1/chat/completions` | +| `-rate ` | ❌ | Requests per second | `500` | +| `-duration ` | ❌ | Test duration in seconds | `10` | +| `-output ` | ❌ | Results output file | `results.json` | + +### **Advanced Configuration** + +| Flag | Description | Default | +|------|-------------|---------| +| `-include-provider-in-request` | Include provider name in request payload | `false` | +| `-big-payload` | Use larger, more complex request payloads | `false` | + +--- + +## Benchmark Scenarios + +### **1. Basic Performance Test** + +Test standard performance with typical request sizes: + +```bash +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 60 -output basic_test.json +``` + +**Use Case**: General performance validation + +### **2. High-Load Stress Test** + +Push your instance to its limits: + +```bash +./benchmark -provider bifrost -port 8080 -rate 5000 -duration 120 -output stress_test.json +``` + +**Use Case**: Capacity planning and SLA validation + +### **3. Large Payload Test** + +Test with bigger request/response sizes: + +```bash +./benchmark -provider bifrost -port 8080 -rate 500 -duration 60 -big-payload=true -output large_payload.json +``` + +**Use Case**: Document processing, code generation workloads + +### **4. Endurance Test** + +Long-running stability test: + +```bash +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 1800 -output endurance_test.json +``` + +**Use Case**: Production readiness validation (30-minute test) + +### **5. Comparative Benchmarking** + +Compare Bifrost against other providers: + +```bash +# Test Bifrost +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 60 -output bifrost_results.json + +# Test LiteLLM +./benchmark -provider litellm -port 8000 -rate 1000 -duration 60 -output litellm_results.json + +# Test direct OpenAI (if available) +./benchmark -provider openai -port 443 -endpoint chat/completions -rate 1000 -duration 60 -output openai_results.json +``` + +--- + +## Understanding Results + +The benchmark tool generates detailed JSON results with comprehensive metrics: + +### **Key Metrics Explained** + +```json +{ + "bifrost": { + "request_counts": { + "total_sent": 30000, + "successful": 30000, + "failed": 0 + }, + "success_rate": 100.0, + "latency_metrics": { + "mean_ms": 245.5, + "p50_ms": 230.2, + "p99_ms": 520.8, + "max_ms": 845.3 + }, + "throughput_rps": 5000.0, + "memory_usage": { + "before_mb": 512.5, + "after_mb": 1312.8, + "peak_mb": 1405.2, + "average_mb": 1156.7 + }, + "timestamp": "2025-01-14T10:30:00Z", + "status_codes": { + "200": 30000 + } + } +} +``` + +### **Critical Performance Indicators** + +**Success Rate:** +- **Target**: >99.9% for production readiness +- **Excellent**: 100% (perfect reliability) + +**Latency Metrics:** +- **P50 (Median)**: Typical user experience +- **P99**: Worst-case user experience +- **Mean**: Overall average performance + +**Memory Usage:** +- **Peak**: Maximum memory consumption +- **Average**: Sustained memory usage +- **After - Before**: Memory growth during test + +--- + +## Instance Sizing Recommendations + +Based on your benchmark results, use these guidelines for production sizing: + +### **Resource Planning Matrix** + +| Target RPS | Memory Usage | Recommended Instance | Notes | +|------------|--------------|---------------------|--------| +| **< 1,000** | < 1GB | t3.small | Cost-effective for light loads | +| **1,000 - 3,000** | 1-2GB | t3.medium | Balanced performance/cost | +| **3,000 - 5,000** | 2-4GB | t3.large | High-performance production | +| **5,000+** | 3-6GB | t3.xlarge+ | Enterprise/mission-critical | + +### **Configuration Tuning Based on Results** + +**If seeing high latency:** +- Increase `initial_pool_size` +- Increase `buffer_size` +- Consider larger instance + +**If memory usage is high:** +- Decrease `initial_pool_size` +- Optimize `buffer_size` +- Monitor for memory leaks + +**If success rate < 100%:** +- Reduce request rate +- Increase timeout settings +- Check provider limits + +--- + +## Advanced Testing Scenarios + +### **Burst Load Testing** + +Simulate traffic spikes: + +```bash +# Normal load +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 300 -output normal_load.json + +# Burst load (simulate 5x spike) +./benchmark -provider bifrost -port 8080 -rate 5000 -duration 60 -output burst_load.json +``` + +### **Multi-Instance Testing** + +Test horizontal scaling: + +```bash +# Instance 1 +./benchmark -provider bifrost-1 -port 8080 -rate 2500 -duration 120 -output instance_1.json & + +# Instance 2 +./benchmark -provider bifrost-2 -port 8081 -rate 2500 -duration 120 -output instance_2.json & + +# Wait for both to complete +wait +``` + +### **Different Payload Sizes** + +Compare performance across payload sizes: + +```bash +# Small payloads (default) +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 60 -output small_payload.json + +# Large payloads +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 60 -big-payload=true -output large_payload.json +``` + +--- + +## Continuous Benchmarking + +### **Automated Testing Pipeline** + +Set up regular performance regression testing: + +```bash +#!/bin/bash +# daily_benchmark.sh + +DATE=$(date +%Y%m%d_%H%M%S) +OUTPUT_DIR="benchmarks/$DATE" +mkdir -p $OUTPUT_DIR + +# Run standard benchmarks +./benchmark -provider bifrost -port 8080 -rate 1000 -duration 300 -output "$OUTPUT_DIR/standard.json" +./benchmark -provider bifrost -port 8080 -rate 3000 -duration 180 -output "$OUTPUT_DIR/high_load.json" +./benchmark -provider bifrost -port 8080 -rate 500 -duration 600 -big-payload=true -output "$OUTPUT_DIR/large_payload.json" + +echo "Benchmarks completed: $OUTPUT_DIR" +``` + +### **Performance Monitoring Integration** + +Monitor key metrics over time: +- **Success rate trends** +- **Latency percentile changes** +- **Memory usage patterns** +- **Throughput capacity** + +--- + +## Troubleshooting + +### **Common Issues** + +**Connection Refused:** +```bash +# Check if Bifrost is running +curl http://localhost:8080/health + +# Verify port configuration +netstat -an | grep 8080 +``` +- Check PORT is defined in `.env` file at root. + +**High Error Rates:** +- Check provider API key limits +- Verify Bifrost configuration +- Monitor upstream provider status +- Reduce request rate for baseline test + +**Memory Issues:** +- Monitor system resources during testing +- Check for memory leaks in long tests +- Adjust Bifrost pool sizes + +**Inconsistent Results:** +- Run multiple test iterations +- Account for network variability +- Use longer test durations (60+ seconds) +- Isolate testing environment +- Try hitting gateway requests to a Mock provider + +--- + +## Next Steps + +### **After Running Benchmarks** + +1. **Analyze Results**: Compare against [official benchmarks](./getting-started) +2. **Optimize Configuration**: Tune based on your specific results +3. **Plan Capacity**: Size instances based on measured performance +4. **Set Up Monitoring**: Track key metrics in production + +### **Compare Results** + +- **[t3.medium Performance](./t3.medium)** - Compare against medium instance results +- **[t3.xlarge Performance](./t3.xl)** - Compare against high-performance configuration + +**Ready to benchmark? Clone the [repository](https://github.com/maximhq/bifrost-benchmarking) and start testing!** diff --git a/docs/benchmarking/t3.medium.mdx b/docs/benchmarking/t3.medium.mdx new file mode 100644 index 000000000..a0371c1d4 --- /dev/null +++ b/docs/benchmarking/t3.medium.mdx @@ -0,0 +1,127 @@ +--- +title: "t3.medium" +description: "Detailed performance metrics and analysis for Bifrost running on AWS t3.medium instances (2 vCPUs, 4GB RAM)." +icon: "server" +--- + +## Instance Configuration + +**AWS t3.medium Specifications:** +- **vCPUs**: 2 +- **Memory**: 4GB RAM +- **Network Performance**: Up to 5 Gigabit + +**Bifrost Configuration:** +- **Buffer Size**: 15,000 +- **Initial Pool Size**: 10,000 +- **Test Load**: 5,000 requests per second (RPS) + +--- + +## Performance Results + +### **Overall Performance Metrics** + +| Metric | Value | Notes | +|--------|-------|--------| +| **Success Rate** | 100.00% | Perfect reliability under high load | +| **Average Request Size** | 0.13 KB | Lightweight request payload | +| **Average Response Size** | 1.37 KB | Standard response size for testing | +| **Average Latency** | 2.12s | Total end-to-end response time | +| **Peak Memory Usage** | 1,312.79 MB | ~33% of available 4GB RAM | + +### **Detailed Performance Breakdown** + +| Operation | Latency | Performance Notes | +|-----------|---------|-------------------| +| **Queue Wait Time** | 47.13 Β΅s | Time waiting in Bifrost's internal queue | +| **Key Selection Time** | 16 ns | Weighted API key selection | +| **Message Formatting** | 2.19 Β΅s | Request message preparation | +| **Params Preparation** | 436 ns | Parameter processing | +| **Request Body Preparation** | 2.65 Β΅s | HTTP request body assembly | +| **JSON Marshaling** | 63.47 Β΅s | JSON serialization time | +| **Request Setup** | 6.59 Β΅s | HTTP client configuration | +| **HTTP Request** | 1.56s | Actual provider API call time | +| **Error Handling** | 189 ns | Error processing overhead | +| **Response Parsing** | 11.30 ms | JSON response deserialization | + +**Bifrost's Total Overhead: 59 Β΅s*** + +*\*Excludes JSON marshalling and HTTP calls, which are required in any implementation* + +--- + +## Performance Analysis + +### **Strengths on t3.medium** + +1. **Perfect Reliability**: 100% success rate even at 5,000 RPS +2. **Memory Efficiency**: Uses only 33% of available RAM (1,312.79 MB / 4GB) +3. **Minimal Overhead**: Just 59 Β΅s of added latency per request +4. **Fast Operations**: Sub-microsecond performance for most internal operations + +### **Resource Utilization** + +- **Memory Usage**: Very efficient at 1,312.79 MB peak usage +- **CPU Performance**: Handles 5,000 RPS workload effectively +- **Queue Management**: 47.13 Β΅s average wait time indicates good throughput + +--- + +## Configuration Recommendations + +### **Optimal Settings for t3.medium** + +Based on test results, these configurations work well: + +```json +{ + "client": { + "initial_pool_size": 10000, + "buffer_size": 15000 + } +} +``` + +### **Tuning Opportunities** + +**For Lower Memory Usage:** +- Reduce `initial_pool_size` to 7,500-8,000 +- Decrease `buffer_size` to 12,000-13,000 +- Trade-off: Slightly higher latency + +**For Better Performance:** +- Increase `initial_pool_size` to 12,000-13,000 +- Increase `buffer_size` to 17,000-18,000 +- Trade-off: Higher memory usage (monitor RAM limits) + +--- + +## Comparison Context + +### **vs. t3.xlarge Performance** + +| Metric | t3.medium | t3.xlarge | Difference | +|--------|-----------|-----------|------------| +| **Bifrost Overhead** | 59 Β΅s | 11 Β΅s | +81% slower | +| **Queue Wait Time** | 47.13 Β΅s | 1.67 Β΅s | +96% slower | +| **JSON Marshaling** | 63.47 Β΅s | 26.80 Β΅s | +58% slower | +| **Response Parsing** | 11.30 ms | 2.11 ms | +81% slower | +| **Memory Usage** | 1,312.79 MB | 3,340.44 MB | -61% usage | + +**Key Insights:** +- t3.medium uses **61% less memory** than t3.xlarge +- Performance trade-offs are reasonable for cost savings +- Most operations still complete in microseconds + +--- + +## Next Steps + +**When to upgrade to t3.xlarge:** +- Sustained load approaches 4,000+ RPS +- Queue wait times consistently exceed 75 Β΅s +- Memory usage approaches 75% of available RAM + +- **[Run Your Own Benchmarks](./run-your-own-benchmarks)** to test with your specific workload +- **[Compare with t3.xlarge](./t3.xl)** for performance scaling analysis diff --git a/docs/benchmarking/t3.xl.mdx b/docs/benchmarking/t3.xl.mdx new file mode 100644 index 000000000..0c9c95210 --- /dev/null +++ b/docs/benchmarking/t3.xl.mdx @@ -0,0 +1,151 @@ +--- +title: "t3.xlarge" +description: "Detailed performance metrics and analysis for Bifrost running on AWS t3.xlarge instances (4 vCPUs, 16GB RAM)." +icon: "server" +--- + +## Instance Configuration + +**AWS t3.xlarge Specifications:** +- **vCPUs**: 4 +- **Memory**: 16GB RAM +- **Network Performance**: Up to 5 Gigabit + +**Bifrost Configuration:** +- **Buffer Size**: 20,000 +- **Initial Pool Size**: 15,000 +- **Test Load**: 5,000 requests per second (RPS) + +--- + +## Performance Results + +### **Overall Performance Metrics** + +| Metric | Value | Notes | +|--------|-------|--------| +| **Success Rate** | 100.00% | Perfect reliability under high load | +| **Average Request Size** | 0.13 KB | Lightweight request payload | +| **Average Response Size** | 10.32 KB | **Large response payload testing** | +| **Average Latency** | 1.61s | Total end-to-end response time | +| **Peak Memory Usage** | 3,340.44 MB | ~21% of available 16GB RAM | + +> **Note**: t3.xlarge tests used significantly larger response payloads (~10 KB vs ~1 KB on t3.medium) to stress-test performance with realistic production data sizes. + +### **Detailed Performance Breakdown** + +| Operation | Latency | Performance Notes | +|-----------|---------|-------------------| +| **Queue Wait Time** | 1.67 Β΅s | **96% faster** than t3.medium | +| **Key Selection Time** | 10 ns | **37% faster** weighted API key selection | +| **Message Formatting** | 2.11 Β΅s | Consistent with t3.medium performance | +| **Params Preparation** | 417 ns | Slight improvement over t3.medium | +| **Request Body Preparation** | 2.36 Β΅s | **11% faster** request assembly | +| **JSON Marshaling** | 26.80 Β΅s | **58% faster** serialization | +| **Request Setup** | 7.17 Β΅s | Comparable to t3.medium | +| **HTTP Request** | 1.50s | **4% faster** provider API calls | +| **Error Handling** | 162 ns | **14% faster** error processing | +| **Response Parsing** | 2.11 ms | **81% faster** despite 7.5x larger payloads | + +**Bifrost's Total Overhead: 11 Β΅s*** + +*\*Excludes JSON marshalling and HTTP calls, which are required in any implementation. 81% reduction compared to t3.medium (59 Β΅s β†’ 11 Β΅s)* + +--- + +## Performance Analysis + +### **Exceptional Performance Improvements** + +1. **Dramatic Overhead Reduction**: 81% lower Bifrost overhead (59 Β΅s β†’ 11 Β΅s) +2. **Superior Queue Management**: 96% faster queue wait times (47.13 Β΅s β†’ 1.67 Β΅s) +3. **Faster JSON Processing**: 58% improvement in marshaling despite larger payloads +4. **Efficient Response Parsing**: 81% faster parsing even with 7.5x larger responses +5. **Perfect Reliability**: 100% success rate maintained under high load + +### **Resource Utilization** + +- **Memory Efficiency**: Uses only 21% of available RAM (3,340.44 MB / 16GB) +- **CPU Performance**: Excellent multi-core utilization for 5,000 RPS +- **Headroom**: Substantial capacity for traffic spikes and growth + +--- + +## Scalability and Headroom + +### **Exceptional Scaling Characteristics** + +The t3.xlarge configuration demonstrates **excellent scaling potential**: + +**Current Utilization:** +- **Memory**: 21% used (13GB available headroom) +- **Queue Performance**: 1.67 Β΅s wait time (near-optimal) +- **Processing Speed**: Sub-microsecond for most operations + +**Scaling Potential:** +- **Traffic Spikes**: Can likely handle 15,000+ RPS bursts +- **Response Size Growth**: Efficiently handles 10 KB responses +- **Concurrent Users**: Supports thousands of simultaneous users + +--- + +## Advanced Configuration + +### **Optimal Settings for t3.xlarge** + +Based on test results, these configurations provide excellent performance: + +```json +{ + "client": { + "initial_pool_size": 15000, + "buffer_size": 20000 + } +} +``` + +### **Performance Tuning Opportunities** + +**For Maximum Performance:** +- Increase `initial_pool_size` to 18,000-20,000 +- Increase `buffer_size` to 25,000-30,000 +- Trade-off: Higher memory usage (still well within limits) + +**For Memory Optimization:** +- Current config already very efficient at 21% RAM usage +- Could reduce settings if needed, but performance gains would be lost + +**For Extreme Workloads:** +- Consider `initial_pool_size` up to 25,000 +- Increase `buffer_size` to 35,000+ +- Monitor memory usage approaching 50% of available RAM + +--- + +## Performance Comparison + +### **vs. t3.medium Performance** + +| Metric | t3.medium | t3.xlarge | Improvement | +|--------|-----------|-----------|-------------| +| **Bifrost Overhead** | 59 Β΅s | 11 Β΅s | **-81%** | +| **Average Latency** | 2.12s | 1.61s | **-24%** | +| **Queue Wait Time** | 47.13 Β΅s | 1.67 Β΅s | **-96%** | +| **JSON Marshaling** | 63.47 Β΅s | 26.80 Β΅s | **-58%** | +| **Response Parsing** | 11.30 ms | 2.11 ms | **-81%** | +| **Response Size Handled** | 1.37 KB | 10.32 KB | **+7.5x** | +| **Peak Memory Usage** | 1,312.79 MB | 3,340.44 MB | +155% | +| **Memory Utilization** | 33% | 21% | **-36%** | + +**Key Insights:** +- **81% overhead reduction** while handling 7.5x larger responses +- **Exceptional efficiency** with only 21% memory utilization +- **Dramatic queue performance** improvements +- **Substantial headroom** for growth and traffic spikes + +--- + +## Next Steps + +- **[Run Your Own Benchmarks](./run-your-own-benchmarks)** with your specific payload sizes +- **[Compare with t3.medium](./t3.medium)** for cost-optimization analysis diff --git a/docs/changelogs/v1.2.21.mdx b/docs/changelogs/v1.2.21.mdx new file mode 100644 index 000000000..8ac2cdaa3 --- /dev/null +++ b/docs/changelogs/v1.2.21.mdx @@ -0,0 +1,50 @@ +--- +title: "v1.2.21" +description: "v1.2.21 changelog" +--- + + +- Fixes pricing computation for nested model names i.e. groq/openai/gpt-oss-20b. + + + + +- Pricing module now accommodates nested model names i.e. groq/openai/gpt-oss-20b was getting skipped while computing costs. + + + + +- Upgrades framework to 1.0.23 + + + + +- Upgrades framework to 1.0.23 + + + + +- Upgrades framework to 1.0.23 +- Fixes pricing computation for nested model names. + + + + +- Upgrades framework to 1.0.23 + + + + +- Upgrades framework to 1.0.23 + + + + +- Upgrades framework to 1.0.23 + + + + +- Upgrades framework to 1.0.23 + + diff --git a/docs/contributing/building-a-plugins.mdx b/docs/contributing/building-a-plugins.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/contributing/code-conventions.mdx b/docs/contributing/code-conventions.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/contributing/raising-a-pr.mdx b/docs/contributing/raising-a-pr.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/contributing/running-tests.mdx b/docs/contributing/running-tests.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/contributing/setting-up-repo.mdx b/docs/contributing/setting-up-repo.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/deployment/docker-setup.mdx b/docs/deployment/docker-setup.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/docs.json b/docs/docs.json new file mode 100644 index 000000000..da2648049 --- /dev/null +++ b/docs/docs.json @@ -0,0 +1,184 @@ +{ + "$schema": "https://mintlify.com/schema.json", + "name": "Bifrost", + "logo": { + "dark": "/media/bifrost-logo-dark.png", + "light": "/media/bifrost-logo.png" + }, + "theme": "palm", + "colors": { + "primary": "#0C3B43", + "light": "#07C983" + }, + "topbarLinks": [ + { + "name": "Support", + "url": "mailto:akshay@getmaxim.ai" + } + ], + "topbarCtaButton": { + "name": "Dashboard", + "url": "https://www.getbifrost.ai" + }, + "anchors": [ + { + "name": "Community", + "icon": "discord", + "url": "https://getmax.im/bifrost-discord" + }, + { + "name": "Blog", + "icon": "newspaper", + "url": "https://getmaxim.ai/blog" + } + ], + "navigation": { + "tabs": [ + { + "tab": "Documentation", + "icon": "book-open-cover", + "groups": [ + { + "group": "Quick Start", + "icon": "rocket", + "pages": [ + { + "group": "Gateway", + "icon": "server", + "pages": [ + "quickstart/gateway/setting-up", + "quickstart/gateway/provider-configuration", + "quickstart/gateway/streaming", + "quickstart/gateway/tool-calling", + "quickstart/gateway/multimodal", + "quickstart/gateway/integrations" + ] + }, + { + "group": "Use as Go SDK", + "icon": "code", + "pages": [ + "quickstart/go-sdk/setting-up", + "quickstart/go-sdk/provider-configuration", + "quickstart/go-sdk/streaming", + "quickstart/go-sdk/tool-calling", + "quickstart/go-sdk/multimodal" + ] + } + ] + }, + { + "group": "Provider Integrations", + "icon": "plug", + "pages": [ + "integrations/what-is-an-integration", + "integrations/openai-sdk", + "integrations/anthropic-sdk", + "integrations/genai-sdk", + "integrations/litellm-sdk", + "integrations/langchain-sdk" + ] + }, + { + "group": "Open Source Features", + "icon": "bolt", + "pages": [ + "features/unified-interface", + "features/drop-in-replacement", + "features/fallbacks", + "features/keys-management", + "features/mcp", + "features/tracing", + "features/telemetry", + "features/observability", + "features/governance", + "features/semantic-caching", + "features/custom-providers", + { + "group": "Plugins", + "icon": "puzzle-piece", + "pages": [ + "features/plugins/mocker", + "features/plugins/jsonparser" + ] + } + ] + }, + { + "group": "Enterprise Features", + "icon": "building", + "pages": [ + "enterprise/clustering", + "enterprise/governance", + "enterprise/mcp-with-fa", + "enterprise/vault-support", + "enterprise/invpc-deployments", + "enterprise/intelligent-load-balancing", + "enterprise/custom-plugins", + "enterprise/log-exports" + ] + } + ] + }, + { + "tab": "API Reference", + "icon": "code", + "groups": [ + { + "group": "API Reference", + "openapi": "apis/openapi.json" + } + ] + }, + { + "tab": "Architecture", + "icon": "codepen", + "pages": [ + { + "group": "Core Architecture", + "icon": "sitemap", + "pages": [ + "architecture/core/concurrency", + "architecture/core/request-flow", + "architecture/core/mcp", + "architecture/core/plugins" + ] + }, + { + "group": "Framework", + "icon": "screwdriver-wrench", + "pages": [ + "architecture/framework/what-is-framework", + "architecture/framework/pricing", + "architecture/framework/vector-store" + ] + } + ] + }, + { + "tab": "Benchmarks", + "icon": "chart-line", + "pages": [ + "benchmarking/getting-started", + "benchmarking/t3.medium", + "benchmarking/t3.xl", + "benchmarking/run-your-own-benchmarks" + ] + }, + { + "tab": "Changelogs", + "icon": "bolt", + "pages": [ + "changelogs/v1.2.21" + ] + } + ] + }, + "footer": { + "socials": { + "x": "https://x.com/getmaximai", + "github": "https://github.com/maximhq/bifrost", + "linkedin": "https://linkedin.com/company/maxim-ai" + } + } +} diff --git a/docs/enterprise/clustering.mdx b/docs/enterprise/clustering.mdx new file mode 100644 index 000000000..3d76006cc --- /dev/null +++ b/docs/enterprise/clustering.mdx @@ -0,0 +1,417 @@ +--- +title: "Clustering" +description: "High-availability peer-to-peer clustering with intelligent traffic distribution, automatic failover, and gossip-based state synchronization for enterprise-scale deployments." +icon: "circle-nodes" +--- + +## Overview + +**Bifrost Clustering** provides enterprise-grade high availability through a peer-to-peer network architecture that ensures continuous service availability, intelligent traffic distribution, and automatic failover capabilities. The clustering system uses gossip protocols to maintain consistent state across all nodes while providing seamless scaling and fault tolerance. + +### Why Clustering is Required + +Modern AI gateway deployments face several critical challenges that clustering addresses: + +| Challenge | Impact | Clustering Solution | +|-----------|--------|-------------------| +| **Single Point of Failure** | Complete service outage if gateway fails | Distributed architecture with automatic failover | +| **Traffic Spikes** | Performance degradation under high load | Dynamic load distribution across multiple nodes | +| **Provider Rate Limits** | Request throttling and service interruption | Distributed rate limit tracking and intelligent routing | +| **Regional Latency** | Poor user experience in distant regions | Geographic distribution with local processing | +| **Maintenance Windows** | Service downtime during updates | Rolling updates with zero-downtime deployment | +| **Capacity Planning** | Over/under-provisioning resources | Elastic scaling based on real-time demand | + +### Key Benefits + +| Feature | Description | +|---------|-------------| +| **Peer-to-Peer Architecture** | No single point of failure with equal node participation | +| **Gossip-Based State Sync** | Real-time synchronization of traffic patterns and limits | +| **Automatic Failover** | Seamless traffic redistribution when nodes fail | +| **Request Migration** | Ongoing requests continue on healthy nodes | +| **Zero-Downtime Updates** | Rolling deployments without service interruption | +| **Intelligent Load Distribution** | AI-driven traffic routing based on node capacity | + +--- + +## Architecture + +### Peer-to-Peer Network Design + +Bifrost clustering uses a **peer-to-peer (P2P) network** where all nodes are equal participants. This design eliminates single points of failure and provides superior fault tolerance compared to master-slave architectures. + +![Clustering diagram](../../media/clustering-diagram.png) + +### Minimum Node Requirements + +**Recommended: 3+ nodes minimum** for optimal fault tolerance and consensus. + +| Cluster Size | Fault Tolerance | Use Case | +|--------------|-----------------|----------| +| **3 nodes** | 1 node failure | Small production deployments | +| **5 nodes** | 2 node failures | Medium production deployments | +| **7+ nodes** | 3+ node failures | Large enterprise deployments | + +--- + +## Gossip Protocol Implementation + +### State Synchronization + +The gossip protocol ensures all nodes maintain consistent views of: + +- **Traffic Patterns**: Request volume, latency metrics, error rates per model-key-id +- **Rate Limit States**: Current usage counters for each provider/model combination +- **Node Health**: CPU, memory, network status of all peers +- **Configuration Changes**: Provider updates, routing rules, policies +- **Model Performance**: Real-time metrics for intelligent load balancing +- **Provider Weights**: Dynamic weight adjustments based on performance + + +### Convergence Guarantees + +- **Eventually Consistent**: All nodes converge to the same state within seconds +- **Partition Tolerance**: Nodes continue operating during network splits +- **Conflict Resolution**: Timestamp-based ordering for conflicting updates + +--- + +## Automatic Failover & Request Migration + +### Node Failure Detection + +Bifrost uses multiple failure detection mechanisms: + +1. **Heartbeat Monitoring**: Regular ping/pong between all nodes +2. **Request Timeout Tracking**: Failed API calls indicate node issues +3. **Gossip Silence Detection**: Missing gossip messages trigger health checks +4. **Load Balancer Health Checks**: External monitoring integration + +### Traffic Redistribution + +When a node fails, traffic is automatically redistributed: + +![Traffic distribution](../../media/traffic-redistribution.png) + +### Request Migration Strategies + +Based on configuration, ongoing requests can be handled in multiple ways: + +| Strategy | Description | Use Case | +|----------|-------------|----------| +| **Complete on Origin** | Requests finish on the original node | Stateful operations | +| **Migrate to Healthy Node** | Transfer to available nodes | Stateless operations | +| **Retry with Backoff** | Restart request on healthy node | Idempotent operations | +| **Circuit Breaker** | Fail fast and return error | Time-sensitive operations | + +--- + +## Configuration + +### Basic Cluster Setup + +```json +{ + "cluster": { + "enabled": true, + "node_id": "bifrost-node-1", + "bind_address": "0.0.0.0:8080", + "peers": [ + "bifrost-node-2:8080", + "bifrost-node-3:8080" + ], + "gossip": { + "port": 7946, + "interval": "1s", + "timeout": "5s" + } + } +} +``` + +### Advanced Clustering Options + +```json +{ + "cluster": { + "enabled": true, + "node_id": "bifrost-node-1", + "bind_address": "0.0.0.0:8080", + "peers": [ + "bifrost-node-2:8080", + "bifrost-node-3:8080" + ], + "gossip": { + "port": 7946, + "interval": "1s", + "timeout": "5s", + "max_packet_size": 1400, + "compression": true + }, + "failover": { + "detection_threshold": 3, + "recovery_timeout": "30s", + "request_migration": "migrate_to_healthy" + }, + "load_balancing": { + "algorithm": "weighted_round_robin", + "health_check_interval": "10s", + "weight_adjustment": "auto" + } + } +} +``` + +### Request Migration Configuration + +```json +{ + "cluster": { + "failover": { + "request_migration": "migrate_to_healthy", + "migration_strategies": { + "chat_completions": "migrate_to_healthy", + "embeddings": "complete_on_origin", + "streaming": "circuit_breaker" + }, + "timeout_behavior": { + "short_timeout": "retry_with_backoff", + "long_timeout": "migrate_to_healthy" + } + } + } +} +``` + +--- + +## Deployment Patterns + +### Docker Compose Cluster + +```yaml +version: '3.8' +services: + bifrost-node-1: + image: bifrost:latest + environment: + - CLUSTER_ENABLED=true + - NODE_ID=bifrost-node-1 + - PEERS=bifrost-node-2:8080,bifrost-node-3:8080 + ports: + - "8080:8080" + - "7946:7946" + + bifrost-node-2: + image: bifrost:latest + environment: + - CLUSTER_ENABLED=true + - NODE_ID=bifrost-node-2 + - PEERS=bifrost-node-1:8080,bifrost-node-3:8080 + ports: + - "8081:8080" + - "7947:7946" + + bifrost-node-3: + image: bifrost:latest + environment: + - CLUSTER_ENABLED=true + - NODE_ID=bifrost-node-3 + - PEERS=bifrost-node-1:8080,bifrost-node-2:8080 + ports: + - "8082:8080" + - "7948:7946" +``` + +### Kubernetes Deployment + +```yaml +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: bifrost-cluster +spec: + serviceName: bifrost-cluster + replicas: 3 + selector: + matchLabels: + app: bifrost + template: + metadata: + labels: + app: bifrost + spec: + containers: + - name: bifrost + image: bifrost:latest + env: + - name: CLUSTER_ENABLED + value: "true" + - name: NODE_ID + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: PEERS + value: "bifrost-cluster-0.bifrost-cluster:8080,bifrost-cluster-1.bifrost-cluster:8080,bifrost-cluster-2.bifrost-cluster:8080" + ports: + - containerPort: 8080 + name: api + - containerPort: 7946 + name: gossip +``` + +--- + +## Monitoring & Observability + +### Cluster Health Metrics + +Monitor these key metrics for cluster health: + +```json +{ + "cluster_metrics": { + "nodes_total": 3, + "nodes_healthy": 3, + "nodes_failed": 0, + "gossip_messages_per_second": 45, + "state_convergence_time_ms": 250, + "request_migration_rate": 0.001, + "load_distribution": { + "node-1": 0.33, + "node-2": 0.34, + "node-3": 0.33 + }, + "provider_performance": { + "openai": { + "total_traffic_percentage": 64.0, + "model_keys": { + "gpt-4-key-1": { + "avg_latency_ms": 1200, + "current_weight": 0.8, + "error_rate": 0.01, + "traffic_percentage": 45.2, + "health_status": "healthy" + }, + "gpt-4-key-2": { + "avg_latency_ms": 1450, + "current_weight": 0.6, + "error_rate": 0.03, + "traffic_percentage": 18.8, + "health_status": "degraded" + } + } + }, + "anthropic": { + "total_traffic_percentage": 36.0, + "model_keys": { + "claude-3-key-1": { + "avg_latency_ms": 980, + "current_weight": 1.0, + "error_rate": 0.005, + "traffic_percentage": 28.5, + "health_status": "healthy" + }, + "claude-3-key-2": { + "avg_latency_ms": 1100, + "current_weight": 0.9, + "error_rate": 0.008, + "traffic_percentage": 7.5, + "health_status": "healthy" + } + } + } + } + } +} +``` + +### Alerting Rules + +Set up alerts for critical cluster events: + +**Cluster-Level Alerts:** +- Node failure detection +- High request migration rates +- Gossip convergence delays +- Uneven load distribution +- Network partition events + +**Model-Key-ID Performance Alerts:** +- High error rates per model-key-id (> 2.5%) +- Latency spikes per model-key-id (> 150% of baseline) +- Weight adjustments frequency (> 10 per minute) +- Traffic imbalance across model keys (> 80% on single key) +- Provider-level performance degradation + +**Example Alert Configuration:** +```yaml +alerts: + - name: "High Error Rate - Model Key" + condition: "error_rate > 0.025" + scope: "model_key_id" + action: "reduce_weight" + + - name: "Latency Spike - Model Key" + condition: "avg_latency_ms > baseline * 1.5" + scope: "model_key_id" + action: "temporary_circuit_break" + + - name: "Traffic Imbalance - Provider" + condition: "single_key_traffic_percentage > 0.8" + scope: "provider" + action: "rebalance_weights" +``` + +--- + +## Best Practices + +### Deployment Guidelines + +1. **Use Odd Number of Nodes**: Prevents split-brain scenarios +2. **Geographic Distribution**: Deploy across availability zones +3. **Resource Sizing**: Ensure nodes can handle redistributed load +4. **Network Security**: Secure gossip communication with encryption +5. **Monitoring Setup**: Implement comprehensive cluster monitoring + +### Performance Optimization + +1. **Gossip Tuning**: Adjust interval based on cluster size and network latency +2. **Load Balancer Configuration**: Use health checks and proper timeouts +3. **Request Routing**: Optimize based on provider latency and capacity +4. **State Compression**: Enable gossip compression for large clusters +5. **Connection Pooling**: Maintain persistent connections between nodes + +### Troubleshooting + +Common issues and solutions: + +| Issue | Symptoms | Solution | +|-------|----------|----------| +| **Split Brain** | Inconsistent responses | Ensure odd number of nodes | +| **Gossip Storms** | High network usage | Tune gossip interval and packet size | +| **Uneven Load** | Some nodes overloaded | Check load balancing configuration | +| **Migration Loops** | Requests bouncing between nodes | Review migration strategies | + +--- + +## Security Considerations + +### Network Security + +- **Gossip Encryption**: Enable TLS for gossip protocol communication +- **API Authentication**: Secure inter-node API calls with mutual TLS +- **Network Segmentation**: Isolate cluster traffic in private networks +- **Firewall Rules**: Restrict gossip ports to cluster nodes only + +### Access Control + +- **Node Authentication**: Verify node identity before joining cluster +- **Configuration Signing**: Cryptographically sign configuration updates +- **Audit Logging**: Track all cluster membership and configuration changes +- **Secret Management**: Secure storage and rotation of cluster secrets + +--- + +This clustering architecture ensures Bifrost can handle enterprise-scale deployments with high availability, automatic failover, and intelligent traffic distribution while maintaining security and performance standards. diff --git a/docs/enterprise/custom-plugins.mdx b/docs/enterprise/custom-plugins.mdx new file mode 100644 index 000000000..44a5a5673 --- /dev/null +++ b/docs/enterprise/custom-plugins.mdx @@ -0,0 +1,16 @@ +--- +title: "Custom Plugins" +description: "Build and deploy enterprise-specific plugins to extend Bifrost's functionality with custom business logic, integrations, and workflow automation." +icon: "plug" +--- + +At Bifrost, we understand that every organization has unique requirements for their LLM infrastructure, workflows, and AI-specific business logic that can't always be addressed by off-the-shelf solutions. That's why we offer comprehensive custom plugin development services to help companies extend Bifrost's LLM gateway functionality with tailored solutions that perfectly fit their specific AI and machine learning needs. + +Our expert team works closely with your organization to design, develop, and deploy custom plugins that integrate seamlessly with your LLM infrastructure and AI workflows. We handle everything from initial consultation to ongoing maintenance. + +- **Custom AI Business Logic Implementation** - Embed your unique AI governance rules and LLM processing logic directly into Bifrost +- **LLM Provider Integrations** - Connect Bifrost with proprietary or specialized LLM providers and AI services +- **AI Workflow Automation** - Automate complex multi-step LLM processes specific to your AI use cases +- **AI Security & Compliance Extensions** - Implement custom AI safety policies, content filtering, and compliance requirements +- **LLM Performance Optimization** - Build plugins optimized for your specific LLM workloads and scaling requirements + diff --git a/docs/enterprise/governance.mdx b/docs/enterprise/governance.mdx new file mode 100644 index 000000000..4eac07fc4 --- /dev/null +++ b/docs/enterprise/governance.mdx @@ -0,0 +1,797 @@ +--- +title: "Governance" +description: "Advanced governance features with enhanced security, compliance reporting, audit trails, and enterprise-grade access controls for large-scale deployments." +icon: "shield-check" +--- + +## Overview + +Enterprise Governance extends Bifrost's [core governance capabilities](../features/governance) with advanced security, compliance, and user management features designed for large-scale enterprise deployments. This module provides comprehensive identity management, regulatory compliance, and detailed audit capabilities. + +**Enterprise Extensions:** +- **Identity & Access Management** - SAML 2.0 and OpenID Connect integration +- **Directory Services** - Active Directory and LDAP user synchronization +- **User-Level Governance** - Individual user authentication and budget allocation +- **Compliance Framework** - SOC 2 Type II, GDPR, ISO 27001, and HIPAA compliance +- **Advanced Auditing** - Comprehensive audit reports and compliance dashboards + +**Builds Upon Core Governance:** +- All standard [Virtual Keys, Teams, and Customers](../features/governance) functionality +- Hierarchical budget management and rate limiting +- Model and provider access controls +- Usage tracking and cost management + +--- + +## SAML & OpenID Connect Integration + +Enterprise Governance provides seamless integration with corporate identity providers through industry-standard authentication protocols. + +### SAML 2.0 Configuration + +**Supported Identity Providers:** +- Microsoft Azure AD / Entra ID +- Okta +- Google Workspace +- Ping Identity (Coming soon) +- Auth0 + + + + +1. **Navigate to Enterprise Settings** + - Open Bifrost UI at `http://localhost:8080` + - Go to **Enterprise** β†’ **Identity Providers** + +2. **Configure SAML Provider** + +**Required Fields:** +- **Provider Name**: Identity provider identifier +- **SSO URL**: SAML SSO endpoint +- **Entity ID**: SAML entity identifier +- **X.509 Certificate**: Identity provider signing certificate + +**Attribute Mapping:** +- **Email Attribute**: `http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress` +- **Name Attribute**: `http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name` +- **Groups Attribute**: `http://schemas.xmlsoap.org/ws/2005/05/identity/claims/groups` +- **Department Attribute**: `http://schemas.xmlsoap.org/ws/2005/05/identity/claims/department` + +**User Provisioning:** +- **Auto-Create Users**: Automatically create users on first login +- **Default Customer**: Assign new users to default customer +- **Default Team**: Assign new users to default team +- **Default Budget**: Initial budget allocation per user + +3. **Save Configuration** + - Click **Configure SAML Provider** + - Test SSO integration + - Enable for production use + + + + +**Configure SAML Provider:** +```bash +curl -X POST http://localhost:8080/api/enterprise/identity-providers \ + -H "Content-Type: application/json" \ + -d '{ + "type": "saml", + "name": "Azure AD Corporate", + "config": { + "sso_url": "https://login.microsoftonline.com/tenant-id/saml2", + "entity_id": "https://sts.windows.net/tenant-id/", + "x509_certificate": "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----", + "attribute_mapping": { + "email": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "name": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name", + "groups": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/groups", + "department": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/department" + }, + "user_provisioning": { + "auto_create": true, + "default_customer_id": "customer-corp", + "default_team_id": "team-general", + "default_budget": { + "max_limit": 50.00, + "reset_duration": "1M" + } + } + }, + "is_active": true + }' +``` + +**Test SAML Configuration:** +```bash +curl -X POST http://localhost:8080/api/enterprise/identity-providers/{provider_id}/test \ + -H "Content-Type: application/json" \ + -d '{ + "test_user_email": "test@company.com" + }' +``` + + + + +```json +{ + "enterprise": { + "identity_providers": [ + { + "id": "saml-azure-ad", + "type": "saml", + "name": "Azure AD Corporate", + "config": { + "sso_url": "https://login.microsoftonline.com/tenant-id/saml2", + "entity_id": "https://sts.windows.net/tenant-id/", + "x509_certificate": "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----", + "attribute_mapping": { + "email": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "name": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name", + "groups": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/groups", + "department": "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/department" + }, + "user_provisioning": { + "auto_create": true, + "default_customer_id": "customer-corp", + "default_team_id": "team-general", + "default_budget": { + "max_limit": 50.00, + "reset_duration": "1M" + } + } + }, + "is_active": true + } + ] + } +} +``` + + + + +### OpenID Connect Configuration + +**Supported Providers:** +- Google Workspace +- Microsoft Azure AD +- Okta +- Auth0 +- Keycloak +- Generic OIDC providers + + + + +1. **Navigate to Identity Providers** + - Go to **Enterprise** β†’ **Identity Providers** + - Click **Add OpenID Connect Provider** + +2. **Configure OIDC Provider** + +**Required Fields:** +- **Provider Name**: OIDC provider identifier +- **Client ID**: Application client identifier +- **Client Secret**: Application client secret +- **Discovery URL**: OIDC discovery endpoint +- **Scopes**: Required OAuth scopes + +**Advanced Settings:** +- **Token Validation**: JWT signature verification +- **Group Claims**: Map OIDC groups to Bifrost teams +- **Role Claims**: Map OIDC roles to permissions + + + + +**Configure OIDC Provider:** +```bash +curl -X POST http://localhost:8080/api/enterprise/identity-providers \ + -H "Content-Type: application/json" \ + -d '{ + "type": "oidc", + "name": "Google Workspace", + "config": { + "client_id": "client-id.apps.googleusercontent.com", + "client_secret": "client-secret", + "discovery_url": "https://accounts.google.com/.well-known/openid_configuration", + "scopes": ["openid", "email", "profile", "groups"], + "claims_mapping": { + "email": "email", + "name": "name", + "groups": "groups", + "department": "department" + }, + "user_provisioning": { + "auto_create": true, + "group_team_mapping": { + "engineering@company.com": "team-eng-001", + "sales@company.com": "team-sales-001" + } + } + }, + "is_active": true + }' +``` + + + + +--- + +## Active Directory Integration + +Enterprise Governance provides native integration with Microsoft Active Directory and LDAP directories for automated user provisioning and group synchronization. + +### Active Directory Configuration + +**Features:** +- **User Synchronization** - Automatic user import and updates +- **Group Mapping** - AD groups to Bifrost teams/customers +- **Attribute Mapping** - Custom user attribute synchronization +- **Scheduled Sync** - Automated periodic synchronization + + + + +1. **Navigate to Directory Services** + - Go to **Enterprise** β†’ **Directory Services** + - Click **Configure Active Directory** + +2. **Connection Settings** + +**Required Fields:** +- **Domain Controller**: AD server hostname/IP +- **Base DN**: Directory search base +- **Bind DN**: Service account distinguished name +- **Bind Password**: Service account password +- **Port**: LDAP port (389 or 636 for SSL) + +**Sync Settings:** +- **User Filter**: LDAP filter for user objects +- **Group Filter**: LDAP filter for group objects +- **Sync Schedule**: Automated sync frequency +- **Sync Scope**: Full or incremental synchronization + +3. **Attribute Mapping** + +**User Attributes:** +- **Email**: `mail` or `userPrincipalName` +- **Display Name**: `displayName` +- **Department**: `department` +- **Manager**: `manager` +- **Employee ID**: `employeeID` + +**Group Mapping:** +- Map AD groups to Bifrost teams +- Set default customer assignments +- Configure budget inheritance + + + + +**Configure Active Directory:** +```bash +curl -X POST http://localhost:8080/api/enterprise/directory-services \ + -H "Content-Type: application/json" \ + -d '{ + "type": "active_directory", + "name": "Corporate AD", + "config": { + "connection": { + "host": "dc.company.com", + "port": 389, + "use_ssl": false, + "base_dn": "DC=company,DC=com", + "bind_dn": "CN=bifrost-service,OU=Service Accounts,DC=company,DC=com", + "bind_password": "service-password" + }, + "sync_settings": { + "user_filter": "(&(objectClass=user)(!(userAccountControl:1.2.840.113556.1.4.803:=2)))", + "group_filter": "(objectClass=group)", + "sync_schedule": "0 2 * * *", + "sync_scope": "incremental" + }, + "attribute_mapping": { + "email": "userPrincipalName", + "name": "displayName", + "department": "department", + "manager": "manager", + "employee_id": "employeeID" + }, + "group_mapping": { + "CN=Engineering,OU=Groups,DC=company,DC=com": { + "team_id": "team-eng-001", + "customer_id": "customer-corp" + }, + "CN=Sales,OU=Groups,DC=company,DC=com": { + "team_id": "team-sales-001", + "customer_id": "customer-corp" + } + } + }, + "is_active": true + }' +``` + +**Trigger Manual Sync:** +```bash +curl -X POST http://localhost:8080/api/enterprise/directory-services/{service_id}/sync \ + -H "Content-Type: application/json" \ + -d '{ + "sync_type": "full" + }' +``` + + + + +### LDAP Configuration + +**Supported LDAP Servers:** +- Microsoft Active Directory +- OpenLDAP +- Apache Directory Server +- Oracle Directory Server +- IBM Security Directory Server + +**Configuration Example:** +```bash +curl -X POST http://localhost:8080/api/enterprise/directory-services \ + -H "Content-Type: application/json" \ + -d '{ + "type": "ldap", + "name": "OpenLDAP Corporate", + "config": { + "connection": { + "host": "ldap.company.com", + "port": 636, + "use_ssl": true, + "base_dn": "ou=people,dc=company,dc=com", + "bind_dn": "cn=bifrost,ou=service,dc=company,dc=com", + "bind_password": "service-password" + }, + "user_mapping": { + "email": "mail", + "name": "cn", + "department": "ou", + "groups": "memberOf" + } + } + }' +``` + +--- + +## User-Level Authentication & Budgeting + +Enterprise Governance extends the hierarchical governance model to include individual user-level controls, providing granular access management and personalized budget allocation. + +### User Management + +**Enhanced Hierarchy:** +``` +Customer (organization-level budget) + ↓ +Team (department-level budget) + ↓ +User (individual-level budget + authentication) + ↓ +Virtual Key (API-level budget + rate limits) +``` + +**User Features:** +- **Individual Authentication** - Personal login credentials +- **Personal Budgets** - User-specific cost allocation +- **Access Controls** - Per-user model and provider restrictions +- **Usage Tracking** - Individual consumption monitoring +- **Audit Trails** - User-specific activity logging + +### User Configuration + + + + +1. **Navigate to Users** + - Go to **Enterprise** β†’ **Users** + - Click **Create User** or import from directory + +2. **User Details** + +**Basic Information:** +- **Email**: Primary identifier +- **Display Name**: Full name +- **Department**: Organizational unit +- **Manager**: Reporting structure +- **Employee ID**: HR system identifier + +**Authentication:** +- **SSO Integration**: Link to identity provider +- **Multi-Factor Auth**: Require MFA for access +- **Session Management**: Control session duration + +**Budget Allocation:** +- **Personal Budget**: Individual spending limit +- **Budget Period**: Reset frequency +- **Inheritance**: Inherit team/customer budgets + +**Access Controls:** +- **Allowed Models**: Restrict model access +- **Allowed Providers**: Restrict provider access +- **Team Assignment**: Primary team membership +- **Customer Assignment**: Organization membership + + + + +**Create User:** +```bash +curl -X POST http://localhost:8080/api/enterprise/users \ + -H "Content-Type: application/json" \ + -d '{ + "email": "alice@company.com", + "display_name": "Alice Johnson", + "department": "Engineering", + "employee_id": "EMP001", + "team_id": "team-eng-001", + "customer_id": "customer-corp", + "authentication": { + "sso_provider_id": "saml-azure-ad", + "require_mfa": true, + "session_duration": "8h" + }, + "budget": { + "max_limit": 25.00, + "reset_duration": "1M", + "inherit_team_budget": true, + "inherit_customer_budget": true + }, + "access_control": { + "allowed_models": ["gpt-4o-mini", "claude-3-haiku-20240307"], + "allowed_providers": ["openai", "anthropic"], + "max_virtual_keys": 3 + }, + "is_active": true + }' +``` + +**Update User:** +```bash +curl -X PUT http://localhost:8080/api/enterprise/users/{user_id} \ + -H "Content-Type: application/json" \ + -d '{ + "budget": { + "max_limit": 50.00, + "reset_duration": "1M" + }, + "access_control": { + "allowed_models": ["gpt-4o", "claude-3-sonnet-20240229"] + } + }' +``` + + + + +### User Authentication Flow + +**SSO Authentication:** +```bash +# 1. Initiate SSO login +curl -X GET http://localhost:8080/api/enterprise/auth/saml/login?provider=azure-ad + +# 2. After SSO callback, get user token +curl -X POST http://localhost:8080/api/enterprise/auth/token \ + -H "Content-Type: application/json" \ + -d '{ + "saml_response": "base64-encoded-saml-response" + }' + +# 3. Use token for API requests +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Authorization: Bearer user-jwt-token" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +**Virtual Key with User Context:** +```bash +# Create user-specific virtual key +curl -X POST http://localhost:8080/api/governance/virtual-keys \ + -H "Authorization: Bearer user-jwt-token" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Alice Personal API Key", + "user_id": "user-alice-001", + "budget": { + "max_limit": 10.00, + "reset_duration": "1w" + } + }' + +# Use virtual key with user tracking +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-vk: vk-alice-personal" \ + -H "x-bf-user-id: user-alice-001" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +--- + +## Compliance Framework + +Enterprise Governance includes built-in compliance capabilities for major regulatory frameworks including **SOC 2 Type II**, **GDPR**, **ISO 27001**, and **HIPAA** compliance. These features provide automated compliance monitoring, policy enforcement, and audit trail generation to meet enterprise security and regulatory requirements. + +--- + +## Audit Reports & Compliance Dashboards + +Enterprise Governance provides comprehensive audit reporting and compliance dashboards for regulatory requirements and internal governance. + +### Audit Report Types + +**1. Access Audit Reports** +- User login/logout activities +- Failed authentication attempts +- Privilege escalation events +- Unusual access patterns + +**2. Usage Audit Reports** +- API request tracking +- Model and provider usage +- Budget consumption patterns +- Rate limit violations + +**3. Data Audit Reports** +- Data access and modification +- Data export activities +- Data deletion requests +- Consent management tracking + +**4. Compliance Reports** +- SOC 2 Type II control evidence +- GDPR compliance status +- ISO 27001 risk assessments +- HIPAA safeguard compliance + +### Report Generation + + + + +1. **Navigate to Audit Reports** + - Go to **Enterprise** β†’ **Audit & Compliance** + - Select **Generate Report** + +2. **Report Configuration** + +**Report Type:** +- **Access Report**: Authentication and authorization events +- **Usage Report**: API consumption and cost analysis +- **Compliance Report**: Regulatory compliance status +- **Security Report**: Security events and incidents + +**Date Range:** +- **Last 24 Hours**: Recent activity +- **Last 7 Days**: Weekly summary +- **Last 30 Days**: Monthly analysis +- **Custom Range**: Specific date range + +**Filters:** +- **Users**: Specific users or all users +- **Teams**: Specific teams or all teams +- **Customers**: Specific customers or all customers +- **Event Types**: Filter by event categories + +**Export Options:** +- **PDF**: Formatted compliance report +- **CSV**: Raw data for analysis +- **JSON**: Structured data export + + + + +**Generate Access Audit Report:** +```bash +curl -X POST http://localhost:8080/api/enterprise/audit/reports \ + -H "Content-Type: application/json" \ + -d '{ + "report_type": "access_audit", + "date_range": { + "start_date": "2024-01-01T00:00:00Z", + "end_date": "2024-01-31T23:59:59Z" + }, + "filters": { + "users": ["user-alice-001", "user-bob-002"], + "event_types": ["login", "logout", "failed_login", "privilege_escalation"] + }, + "format": "pdf", + "include_summary": true + }' +``` + +**Generate Usage Audit Report:** +```bash +curl -X POST http://localhost:8080/api/enterprise/audit/reports \ + -H "Content-Type: application/json" \ + -d '{ + "report_type": "usage_audit", + "date_range": { + "start_date": "2024-01-01T00:00:00Z", + "end_date": "2024-01-31T23:59:59Z" + }, + "filters": { + "customers": ["customer-corp"], + "models": ["gpt-4o", "claude-3-sonnet-20240229"], + "providers": ["openai", "anthropic"] + }, + "format": "csv", + "include_cost_analysis": true + }' +``` + +**Generate Compliance Report:** +```bash +curl -X POST http://localhost:8080/api/enterprise/audit/reports \ + -H "Content-Type: application/json" \ + -d '{ + "report_type": "compliance", + "compliance_framework": "soc2_type2", + "date_range": { + "start_date": "2024-01-01T00:00:00Z", + "end_date": "2024-01-31T23:59:59Z" + }, + "control_objectives": ["security", "availability", "confidentiality"], + "format": "pdf", + "include_evidence": true + }' +``` + + + + +### Compliance Dashboards + +**Real-Time Monitoring:** +- **Security Posture**: Current security status and alerts +- **Compliance Status**: Regulatory compliance health check +- **Risk Assessment**: Identified risks and mitigation status +- **Audit Trail**: Recent audit events and activities + +**Dashboard Widgets:** +```bash +curl -X GET http://localhost:8080/api/enterprise/dashboard/compliance \ + -H "Authorization: Bearer admin-token" + +# Response includes: +{ + "security_posture": { + "overall_score": 95, + "active_alerts": 2, + "failed_logins_24h": 5, + "privilege_escalations": 0 + }, + "compliance_status": { + "soc2_type2_compliance": "compliant", + "gdpr_compliance": "compliant", + "iso27001_compliance": "in_progress", + "hipaa_compliance": "not_applicable" + }, + "risk_assessment": { + "high_risk_items": 0, + "medium_risk_items": 3, + "low_risk_items": 12, + "mitigation_progress": "85%" + }, + "recent_activities": [ + { + "timestamp": "2024-01-15T10:30:00Z", + "type": "user_login", + "user": "alice@company.com", + "status": "success" + } + ] +} +``` + +### Automated Compliance Monitoring + +**Continuous Monitoring:** +```bash +curl -X POST http://localhost:8080/api/enterprise/compliance/monitoring \ + -H "Content-Type: application/json" \ + -d '{ + "monitoring_rules": [ + { + "name": "Failed Login Monitoring", + "type": "security_event", + "condition": "failed_logins > 10 in 1h", + "action": "alert_security_team", + "severity": "high" + }, + { + "name": "Data Export Monitoring", + "type": "data_access", + "condition": "data_export_size > 1GB", + "action": "require_approval", + "severity": "medium" + }, + { + "name": "Budget Threshold Alert", + "type": "budget_usage", + "condition": "usage > 80% of budget", + "action": "notify_manager", + "severity": "low" + } + ], + "notification_channels": { + "email": ["security@company.com", "compliance@company.com"], + "slack": "#security-alerts", + "webhook": "https://company.com/security-webhook" + } + }' +``` + +--- + +## Error Responses + +Enterprise Governance extends standard governance errors with additional authentication and compliance-related responses: + +**Authentication Errors:** +```json +{ + "error": { + "type": "authentication_required", + "message": "SSO authentication required" + } +} +``` + +```json +{ + "error": { + "type": "mfa_required", + "message": "Multi-factor authentication required" + } +} +``` + +**Authorization Errors:** +```json +{ + "error": { + "type": "user_not_authorized", + "message": "User does not have permission to access this model" + } +} +``` + +**Compliance Errors:** +```json +{ + "error": { + "type": "compliance_violation", + "message": "Request violates GDPR data minimization requirements" + } +} +``` + +--- + +## Next Steps + +- **[Core Governance](../features/governance)** - Understand base governance concepts +- **[Clustering](./clustering)** - Deploy enterprise governance across multiple nodes +{/* - **[SSO Integration](./sso-saml-openid-connect)** - Detailed SSO configuration guide */} +- **[Vault Support](./vault-support)** - Secure credential management +- **[Custom Plugins](./custom-plugins)** - Extend enterprise governance capabilities diff --git a/docs/enterprise/intelligent-load-balancing.mdx b/docs/enterprise/intelligent-load-balancing.mdx new file mode 100644 index 000000000..13eb0f264 --- /dev/null +++ b/docs/enterprise/intelligent-load-balancing.mdx @@ -0,0 +1,371 @@ +--- +title: "Adaptive Load Balancing" +description: "Advanced load balancing algorithms with predictive scaling, health monitoring, and performance optimization for enterprise-grade traffic distribution." +icon: "brain" +--- + +## Overview + +**Adaptive Load Balancing** in Bifrost automatically optimizes traffic distribution across provider keys and models based on real-time performance metrics. The system continuously monitors error rates, latency, and throughput to dynamically adjust weights, ensuring optimal performance and reliability. + +### Key Features + +| Feature | Description | +|---------|-------------| +| **Dynamic Weight Adjustment** | Automatically adjusts key weights based on performance metrics | +| **Real-time Performance Monitoring** | Tracks error rates, latency, and success rates per model-key combination | +| **Cross-Node Synchronization** | Gossip protocol ensures consistent weight information across all cluster nodes | +| **Predictive Scaling** | Anticipates traffic patterns and adjusts weights proactively | +| **Circuit Breaker Integration** | Temporarily removes poorly performing keys from rotation | +| **Model-Level Optimization** | Optimizes performance at both provider and individual model levels | + +--- + +## How Adaptive Load Balancing Works + +### Performance Metrics Collection + +The system continuously collects performance data for each model-key combination: + +```json +{ + "provider": "openai", + "model_key_id": "gpt-4-key-1", + "metrics": { + "avg_latency_ms": 1200, + "error_rate": 0.01, + "success_rate": 0.99, + "requests_per_minute": 362, + "tokens_processed": 87500, + "current_weight": 0.8, + "baseline_latency_ms": 980, + "performance_score": 0.85 + } +} +``` + +### Weight Adjustment Algorithm + +The adaptive load balancer automatically adjusts weights based on real-time performance metrics: + +- **High Error Rates**: Reduces weight for keys with elevated error rates +- **Latency Spikes**: Decreases weight when response times exceed baseline thresholds +- **Superior Performance**: Increases weight for consistently high-performing keys +- **Gradual Adjustments**: Makes incremental changes to prevent traffic oscillation + +### Real-Time Weight Synchronization + +In clustered deployments, weight adjustments are synchronized across all nodes using the gossip protocol: + +#### Weight Update Message Format + +```json +{ + "version": 1, + "type": "weight_update", + "node_id": "bifrost-node-b", + "timestamp": "2024-01-15T10:30:15Z", + "data": { + "provider": "openai", + "model_key_id": "gpt-4-key-2", + "weight_change": { + "from": 0.8, + "to": 0.6, + "reason": "high_error_rate", + "threshold_exceeded": 0.025, + "adjustment_factor": 0.75 + }, + "performance_metrics": { + "avg_latency_ms": 1450, + "baseline_latency_ms": 1100, + "error_rate": 0.03, + "success_rate": 0.97, + "requests_count": 150, + "performance_score": 0.72 + }, + "next_evaluation": "2024-01-15T10:31:15Z" + } +} +``` + +--- + +## Performance Monitoring & Alerting + +### Key Performance Indicators + +The system tracks these critical metrics for each model-key combination: + +| Metric | Threshold | Action | +|--------|-----------|--------| +| **Error Rate** | > 2.5% | Reduce weight by 30% | +| **Latency Spike** | > 150% baseline | Reduce weight by 20% | +| **Success Rate** | < 95% | Circuit breaker activation | +| **Response Time** | > 5000ms | Temporary removal from pool | +| **Throughput Drop** | < 50% expected | Weight adjustment | + +### Automatic Performance Alerts + +```json +{ + "version": 1, + "type": "performance_alert", + "node_id": "bifrost-node-c", + "timestamp": "2024-01-15T10:31:00Z", + "data": { + "alert_type": "latency_spike", + "severity": "warning", + "provider": "anthropic", + "model_key_id": "claude-3-key-1", + "current_metrics": { + "avg_latency_ms": 2800, + "baseline_latency_ms": 980, + "spike_percentage": 185.7, + "error_rate": 0.008, + "current_weight": 1.0 + }, + "recommended_action": "reduce_weight", + "suggested_new_weight": 0.7, + "auto_applied": true + } +} +``` + +--- + +## Configuration + +### Basic Adaptive Load Balancing Setup + +```json +{ + "adaptive_load_balancing": { + "enabled": true, + "algorithm": "adaptive_weighted", + "evaluation_interval": "30s", + "weight_adjustment": { + "enabled": true, + "max_change_per_cycle": 0.3, + "min_weight": 0.1, + "max_weight": 2.0 + }, + "performance_thresholds": { + "error_rate_warning": 0.02, + "error_rate_critical": 0.05, + "latency_spike_threshold": 1.5, + "circuit_breaker_threshold": 0.95 + } + } +} +``` + +### Advanced Configuration + +```json +{ + "adaptive_load_balancing": { + "enabled": true, + "algorithm": "adaptive_weighted", + "evaluation_interval": "30s", + "weight_adjustment": { + "enabled": true, + "strategy": "performance_based", + "max_change_per_cycle": 0.3, + "min_weight": 0.1, + "max_weight": 2.0, + "adjustment_factors": { + "error_rate_penalty": 0.7, + "latency_penalty": 0.8, + "performance_bonus": 1.1 + } + }, + "performance_thresholds": { + "error_rate_warning": 0.02, + "error_rate_critical": 0.05, + "latency_spike_threshold": 1.5, + "latency_critical_threshold": 2.0, + "circuit_breaker_threshold": 0.95, + "recovery_threshold": 0.98 + }, + "metrics_collection": { + "window_size": "5m", + "sample_rate": "1s", + "baseline_calculation": "rolling_average_7d" + }, + "predictive_scaling": { + "enabled": true, + "prediction_window": "15m", + "confidence_threshold": 0.8, + "proactive_adjustments": true + } + } +} +``` + +### Provider-Specific Configuration + +```json +{ + "providers": [ + { + "id": "openai", + "keys": [ + { + "key": "sk-...", + "model_key_id": "gpt-4-key-1", + "weight": 1.0, + "adaptive_balancing": { + "enabled": true, + "baseline_latency_ms": 1100, + "expected_error_rate": 0.01, + "max_requests_per_minute": 500, + "priority": "high" + } + }, + { + "key": "sk-...", + "model_key_id": "gpt-4-key-2", + "weight": 0.8, + "adaptive_balancing": { + "enabled": true, + "baseline_latency_ms": 1200, + "expected_error_rate": 0.015, + "max_requests_per_minute": 400, + "priority": "medium" + } + } + ] + } + ] +} +``` + +--- + +## Traffic Distribution Examples + +### Before Adaptive Load Balancing + +```json +{ + "provider": "openai", + "traffic_distribution": { + "gpt-4-key-1": { + "weight": 1.0, + "traffic_percentage": 50.0, + "avg_latency_ms": 1450, + "error_rate": 0.03, + "status": "degraded_performance" + }, + "gpt-4-key-2": { + "weight": 1.0, + "traffic_percentage": 50.0, + "avg_latency_ms": 1100, + "error_rate": 0.01, + "status": "healthy" + } + } +} +``` + +### After Adaptive Load Balancing + +```json +{ + "provider": "openai", + "traffic_distribution": { + "gpt-4-key-1": { + "weight": 0.6, + "traffic_percentage": 35.3, + "avg_latency_ms": 1450, + "error_rate": 0.03, + "status": "weight_reduced", + "adjustment_reason": "high_error_rate_and_latency" + }, + "gpt-4-key-2": { + "weight": 1.1, + "traffic_percentage": 64.7, + "avg_latency_ms": 1100, + "error_rate": 0.01, + "status": "weight_increased", + "adjustment_reason": "superior_performance" + } + }, + "overall_improvement": { + "avg_latency_reduction": "12.3%", + "error_rate_reduction": "23.1%", + "throughput_increase": "8.7%" + } +} +``` + +--- + +## Monitoring Dashboard + +### Real-Time Performance View + +Monitor adaptive load balancing effectiveness through these key metrics: + +```json +{ + "adaptive_load_balancing_metrics": { + "last_evaluation": "2024-01-15T10:30:00Z", + "next_evaluation": "2024-01-15T10:30:30Z", + "total_adjustments_last_hour": 12, + "performance_improvements": { + "latency_improvement": "15.2%", + "error_rate_reduction": "28.4%", + "throughput_increase": "11.8%" + }, + "provider_performance": { + "openai": { + "total_keys": 3, + "healthy_keys": 2, + "degraded_keys": 1, + "avg_weight": 0.83, + "traffic_distribution": { + "gpt-4-key-1": { + "weight": 0.6, + "traffic_percentage": 28.5, + "performance_score": 0.72, + "trend": "declining" + }, + "gpt-4-key-2": { + "weight": 1.1, + "traffic_percentage": 52.3, + "performance_score": 0.94, + "trend": "stable" + }, + "gpt-4-key-3": { + "weight": 0.9, + "traffic_percentage": 19.2, + "performance_score": 0.87, + "trend": "improving" + } + } + }, + "anthropic": { + "total_keys": 2, + "healthy_keys": 2, + "degraded_keys": 0, + "avg_weight": 1.05, + "traffic_distribution": { + "claude-3-key-1": { + "weight": 1.0, + "traffic_percentage": 48.2, + "performance_score": 0.91, + "trend": "stable" + }, + "claude-3-key-2": { + "weight": 1.1, + "traffic_percentage": 51.8, + "performance_score": 0.95, + "trend": "improving" + } + } + } + } + } +} +``` \ No newline at end of file diff --git a/docs/enterprise/invpc-deployments.mdx b/docs/enterprise/invpc-deployments.mdx new file mode 100644 index 000000000..42e2f30c1 --- /dev/null +++ b/docs/enterprise/invpc-deployments.mdx @@ -0,0 +1,108 @@ +--- +title: "In-VPC Deployments" +description: "Deploy Bifrost within your private cloud infrastructure with VPC isolation, custom networking, and enhanced security controls for enterprise environments." +icon: "cloud" +--- + +In-VPC (Virtual Private Cloud) deployments allow you to run Bifrost entirely within your private cloud infrastructure, providing maximum security, compliance, and control over your AI gateway deployment. + +## Supported Cloud Providers + +Bifrost supports INVPC deployments across all major cloud providers: + +
+
+ Google Cloud Platform +
+
+ Amazon Web Services +
+
+ Microsoft Azure +
+
+ Cloudflare +
+
+ Vercel +
+
+ +## Architecture Benefits + +### Security & Compliance +- **Network Isolation**: Complete isolation within your VPC with no external network dependencies +- **Data Sovereignty**: All data processing occurs within your controlled environment +- **Compliance Ready**: Meets requirements for HIPAA, SOC2, GDPR, and other regulatory frameworks +- **Zero Trust Architecture**: Implements principle of least privilege with granular access controls + +### Performance & Reliability +- **Low Latency**: Direct communication between services within your network +- **High Availability**: Multi-zone deployment with automatic failover capabilities +- **Guaranteed Uptime**: 99.95% SLA with comprehensive monitoring and alerting + +### Control & Customization +- **Custom Networking**: Configure subnets, routing, and security groups to your specifications +- **Resource Management**: Full control over compute, storage, and network resources +- **Scaling Policies**: Define auto-scaling rules based on your usage patterns + +## Service Level Agreement + +### Availability Commitment +- **Uptime Guarantee**: 99.95% monthly uptime for all core components +- **Downtime Calculation**: `(Total Minutes - Downtime Minutes) / Total Minutes Γ— 100` +- **Partial Downtime**: Reduced functionality counted as 50% downtime + +### Core Components Covered +The following components are monitored for SLA compliance: +- Gateway instance +- Log ingestion pipeline + +### Exclusions +SLA excludes downtime due to: +- Scheduled maintenance (14-day advance notice) +- Downstream provider incidents +- Client hardware/software/network issues +- Third-party AI provider outages +- Client misuse or unauthorized modifications + +## Support & Maintenance + +### Technical Support +- **24/7 Critical Support**: Available for core component issues +- **Multiple Channels**: Platform, email (contact@getmaxim.ai), or Slack Connect +- **Audit Trail**: Detailed logs for any data access during troubleshooting + +### Maintenance Windows +- **Scheduled Maintenance**: 14-day advance notice for major updates +- **Security Patches**: Immediate or 14-day delayed application (your choice) +- **Continuous Updates**: Regular feature improvements with 7-day advance notice + +## Getting Started + +### Prerequisites +- VPC with appropriate CIDR ranges +- Kubernetes cluster (GKE, EKS, or AKS) +- Container registry access +- DNS configuration for internal routing + +### Deployment Process +1. **Infrastructure Setup**: Configure VPC, subnets, and security groups +2. **Cluster Preparation**: Set up Kubernetes cluster with required permissions +3. **Bifrost Installation**: Deploy using provided Helm charts or manifests +4. **Configuration**: Apply your specific settings and integrations +5. **Validation**: Run connectivity and performance tests +6. **Go Live**: Begin routing production traffic + + +## Cost Optimization + +### Resource Sizing +- **Development**: 2 vCPU, 4GB RAM minimum +- **Production**: 4+ vCPU, 8GB+ RAM recommended +- **High Availability**: Multi-zone deployment with load balancing + +### Scaling Strategies +- **Horizontal Pod Autoscaling**: Based on CPU/memory utilization +- **Vertical Pod Autoscaling**: Automatic resource adjustment +- **Cluster Autoscaling**: Node pool expansion/contraction diff --git a/docs/enterprise/log-exports.mdx b/docs/enterprise/log-exports.mdx new file mode 100644 index 000000000..a179e79a5 --- /dev/null +++ b/docs/enterprise/log-exports.mdx @@ -0,0 +1,348 @@ +--- +title: "Log Exports" +description: "Export and analyze request logs, traces, and telemetry data from Bifrost with enterprise-grade data export capabilities for compliance, monitoring, and analytics." +icon: "download" +--- + +# Log Exports + +Bifrost Enterprise provides comprehensive log export capabilities, allowing you to automatically export request logs, traces, and telemetry data to various storage systems and data lakes on configurable schedules. + +## Overview + +The log export system enables: +- **Scheduled Exports**: Daily, weekly, or monthly automated exports +- **Multiple Destinations**: Object stores, data warehouses, and data lakes +- **Format Flexibility**: JSON, CSV, Parquet, and custom formats +- **Filtering & Transformation**: Export specific data subsets with custom transformations +- **Compliance**: Meet data retention and audit requirements + +## Supported Export Destinations + +### Object Storage + +#### Amazon S3 +```json +{ + "export": { + "destination": "s3", + "config": { + "bucket": "bifrost-logs", + "region": "us-west-2", + "prefix": "logs/{year}/{month}/{day}/", + "credentials": { + "access_key_id": "${AWS_ACCESS_KEY_ID}", + "secret_access_key": "${AWS_SECRET_ACCESS_KEY}" + } + } + } +} +``` + +#### Google Cloud Storage +```json +{ + "export": { + "destination": "gcs", + "config": { + "bucket": "bifrost-logs", + "prefix": "logs/{year}/{month}/{day}/", + "credentials": { + "service_account_key": "${GCP_SERVICE_ACCOUNT_KEY}" + } + } + } +} +``` + +#### Azure Blob Storage +```json +{ + "export": { + "destination": "azure_blob", + "config": { + "container": "bifrost-logs", + "account_name": "${AZURE_ACCOUNT_NAME}", + "account_key": "${AZURE_ACCOUNT_KEY}", + "prefix": "logs/{year}/{month}/{day}/" + } + } +} +``` + +### Data Warehouses & Lakes + +#### Snowflake +```json +{ + "export": { + "destination": "snowflake", + "config": { + "account": "your-account.snowflakecomputing.com", + "database": "BIFROST_LOGS", + "schema": "PUBLIC", + "table": "request_logs", + "warehouse": "COMPUTE_WH", + "credentials": { + "username": "${SNOWFLAKE_USERNAME}", + "password": "${SNOWFLAKE_PASSWORD}" + } + } + } +} +``` + +#### Amazon Redshift +```json +{ + "export": { + "destination": "redshift", + "config": { + "cluster": "bifrost-cluster", + "database": "bifrost_logs", + "schema": "public", + "table": "request_logs", + "region": "us-west-2", + "credentials": { + "username": "${REDSHIFT_USERNAME}", + "password": "${REDSHIFT_PASSWORD}" + } + } + } +} +``` + +#### Google BigQuery +```json +{ + "export": { + "destination": "bigquery", + "config": { + "project_id": "your-project-id", + "dataset": "bifrost_logs", + "table": "request_logs", + "credentials": { + "service_account_key": "${GCP_SERVICE_ACCOUNT_KEY}" + } + } + } +} +``` + +## Export Schedules + +### Daily Exports +```json +{ + "export": { + "schedule": "daily", + "time": "02:00", + "timezone": "UTC" + } +} +``` + +### Weekly Exports +```json +{ + "export": { + "schedule": "weekly", + "day": "sunday", + "time": "03:00", + "timezone": "UTC" + } +} +``` + +### Monthly Exports +```json +{ + "export": { + "schedule": "monthly", + "day": 1, + "time": "04:00", + "timezone": "UTC" + } +} +``` + +## Export Configuration + +### Complete Export Configuration Example + +```json +{ + "log_exports": { + "enabled": true, + "exports": [ + { + "name": "daily_s3_export", + "enabled": true, + "schedule": { + "frequency": "daily", + "time": "02:00", + "timezone": "UTC" + }, + "destination": { + "type": "s3", + "config": { + "bucket": "bifrost-logs-prod", + "region": "us-west-2", + "prefix": "daily-exports/{year}/{month}/{day}/", + "credentials": { + "access_key_id": "${AWS_ACCESS_KEY_ID}", + "secret_access_key": "${AWS_SECRET_ACCESS_KEY}" + } + } + }, + "data": { + "format": "parquet", + "compression": "gzip", + "include": [ + "request_logs", + "response_logs", + "error_logs" + ], + "filters": { + "date_range": "last_24_hours", + "status_codes": [200, 400, 401, 403, 404, 500] + } + } + }, + { + "name": "weekly_bigquery_export", + "enabled": true, + "schedule": { + "frequency": "weekly", + "day": "sunday", + "time": "03:00", + "timezone": "UTC" + }, + "destination": { + "type": "bigquery", + "config": { + "project_id": "your-analytics-project", + "dataset": "bifrost_analytics", + "table": "weekly_logs", + "credentials": { + "service_account_key": "${GCP_SERVICE_ACCOUNT_KEY}" + } + } + }, + "data": { + "format": "json", + "include": [ + "request_logs", + "metrics", + "traces" + ], + "transformations": [ + { + "type": "aggregate", + "group_by": ["provider", "model", "customer_id"], + "metrics": ["total_requests", "avg_latency", "error_rate"] + } + ] + } + } + ] + } +} +``` + +## Data Formats + +### JSON Format +```json +{ + "timestamp": "2024-01-15T10:30:00Z", + "request_id": "req_123456789", + "customer_id": "cust_abc123", + "provider": "openai", + "model": "gpt-4", + "endpoint": "/v1/chat/completions", + "method": "POST", + "status_code": 200, + "latency_ms": 1250, + "input_tokens": 100, + "output_tokens": 150, + "cost_usd": 0.0045 +} +``` + +### CSV Format +```csv +timestamp,request_id,customer_id,provider,model,endpoint,method,status_code,latency_ms,input_tokens,output_tokens,cost_usd +2024-01-15T10:30:00Z,req_123456789,cust_abc123,openai,gpt-4,/v1/chat/completions,POST,200,1250,100,150,0.0045 +``` + +### Parquet Schema +``` +message log_record { + required int64 timestamp; + required binary request_id (UTF8); + required binary customer_id (UTF8); + required binary provider (UTF8); + required binary model (UTF8); + required binary endpoint (UTF8); + required binary method (UTF8); + required int32 status_code; + required int32 latency_ms; + optional int32 input_tokens; + optional int32 output_tokens; + optional double cost_usd; +} +``` + +## Data Filtering & Transformation + +### Filtering Options +```json +{ + "filters": { + "date_range": { + "start": "2024-01-01T00:00:00Z", + "end": "2024-01-31T23:59:59Z" + }, + "providers": ["openai", "anthropic", "azure"], + "models": ["gpt-4", "claude-3-sonnet"], + "status_codes": [200, 201, 400, 401, 403, 404, 500], + "customers": ["cust_123", "cust_456"], + "min_latency_ms": 100, + "max_latency_ms": 10000, + "has_errors": true + } +} +``` + +### Transformation Options +```json +{ + "transformations": [ + { + "type": "aggregate", + "group_by": ["provider", "model", "date"], + "metrics": [ + "count", + "avg_latency", + "p95_latency", + "total_tokens", + "total_cost", + "error_rate" + ] + }, + { + "type": "anonymize", + "fields": ["customer_id", "request_id"], + "method": "hash" + }, + { + "type": "enrich", + "add_fields": { + "export_timestamp": "${EXPORT_TIMESTAMP}", + "export_version": "${EXPORT_VERSION}" + } + } + ] +} +``` \ No newline at end of file diff --git a/docs/enterprise/mcp-with-fa.mdx b/docs/enterprise/mcp-with-fa.mdx new file mode 100644 index 000000000..8959fa106 --- /dev/null +++ b/docs/enterprise/mcp-with-fa.mdx @@ -0,0 +1,189 @@ +--- +title: "MCP with Federated Auth" +description: "Transform your existing private enterprise APIs into LLM-ready MCP tools using federated authentication without writing a single line of code" +icon: "screwdriver-wrench" +--- + +Transform your existing private enterprise APIs into LLM-ready MCP tools instantly. Add your APIs along with authentication information, and Bifrost dynamically syncs user authentication to allow these existing APIs to be used as MCP tools. + +## Supported Import Methods + +Add your existing APIs to Bifrost using any of these methods: + + + +Import your existing Postman collections directly into Bifrost. All request configurations, headers, and parameters are preserved. + +```json +{ + "info": { + "name": "Enterprise API Collection", + "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" + }, + "item": [ + { + "name": "Get User Profile", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "{{req.header.authorization}}", + "type": "text" + } + ], + "url": { + "raw": "https://api.company.com/users/profile", + "host": ["api", "company", "com"], + "path": ["users", "profile"] + } + } + } + ] +} +``` + + + +Use your existing OpenAPI 3.0+ specifications. Bifrost automatically converts them into MCP-compatible tools. + +```yaml +openapi: 3.0.0 +info: + title: Enterprise API + version: 1.0.0 +paths: + /users/profile: + get: + summary: Get user profile + security: + - BearerAuth: [] + parameters: + - name: Authorization + in: header + schema: + type: string + example: "{{req.header.authorization}}" +components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer +``` + + + +Convert your existing cURL commands directly into MCP tools. + +```bash +curl -X GET "https://api.company.com/users/profile" \ + -H "Authorization: {{req.header.authorization}}" \ + -H "Content-Type: application/json" +``` + + + +Use Bifrost's intuitive UI to manually configure your API endpoints with the same ease as Postman. + +1. Set HTTP method and URL +2. Configure headers with variable substitution +3. Define request body (if needed) +4. Test the endpoint +5. Deploy as MCP tool + + + +## What Happens Next + +Once you upload your API specifications, Bifrost automatically: + +- **Syncs authentication systems** from your existing APIs +- **Converts endpoints** into MCP-compatible tools +- **Maintains security** using your current auth infrastructure +- **Makes APIs available** to LLMs instantly + +## Supported Authentication Types + +Bifrost automatically handles all common authentication patterns: + +- **Bearer Tokens** (JWT, OAuth) +- **API Keys** (headers, query parameters) +- **Custom Headers** (tenant IDs, user tokens) +- **Basic Auth** and other standard methods + +## Real-World Use Cases + +### Enterprise CRM Integration + +Transform your Salesforce, HubSpot, or custom CRM APIs: + +```json +{ + "name": "Get Customer Data", + "method": "GET", + "url": "https://api.company.com/crm/customers/{{req.body.customer_id}}", + "headers": { + "Authorization": "{{req.header.authorization}}", + "X-Tenant-ID": "{{req.header.x-tenant-id}}" + } +} +``` + +### Internal Microservices + +Make your internal microservices LLM-accessible: + +```yaml +paths: + /internal/user-service/profile: + get: + parameters: + - name: Authorization + in: header + schema: + type: string + default: "{{req.header.authorization}}" + - name: X-Service-Token + in: header + schema: + type: string + default: "{{env.INTERNAL_SERVICE_TOKEN}}" +``` + +### Database APIs + +Connect to your database APIs securely: + +```http +POST https://db-api.company.com/query +Content-Type: application/json +Authorization: {{req.header.authorization}} +X-Database-Name: {{req.header.x-database}} + +{ + "query": "SELECT * FROM users WHERE tenant_id = '{{req.body.tenant_id}}'", + "limit": 100 +} +``` + +## Security Benefits + +### 1. **Zero Trust Architecture** +- Authentication happens at the edge (your existing systems) +- Bifrost never stores or caches authentication credentials +- Each request is authenticated independently + +### 2. **Existing Security Policies** +- Leverage your current RBAC (Role-Based Access Control) +- Maintain existing audit trails +- No changes to security infrastructure required + +### 3. **Granular Access Control** +- Different users get different API access based on their credentials +- Tenant isolation maintained through existing headers +- API rate limiting and quotas preserved + +### 4. **Compliance Friendly** +- No sensitive data passes through Bifrost permanently +- Existing compliance frameworks remain intact +- Audit trails maintained in your systems \ No newline at end of file diff --git a/docs/enterprise/vault-support.mdx b/docs/enterprise/vault-support.mdx new file mode 100644 index 000000000..ef2ffc4c4 --- /dev/null +++ b/docs/enterprise/vault-support.mdx @@ -0,0 +1,182 @@ +--- +title: "Vault Support" +description: "Secure API key management with HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault integration. Store and retrieve sensitive credentials using enterprise-grade secret management." +icon: "vault" +--- + +Bifrost's vault support enables seamless integration with enterprise-grade secret management systems, allowing you to connect to existing vaults and automatically sync virtual keys and provider API keys directly onto the Bifrost platform. + +## Overview + +The vault integration provides: + +- **Automated Key Synchronization**: Connect to your existing vault infrastructure and sync all API keys automatically +- **Periodic Key Management**: Regular synchronization ensures deprecated and archived keys are properly managed +- **Multi-Vault Support**: Compatible with HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault +- **Zero-Downtime Operations**: Keys are synced without interrupting your running services + +## Supported Vault Systems + +### HashiCorp Vault + +Connect to your HashiCorp Vault instance for centralized secret management. + +```json +{ + "vault": { + "type": "hashicorp", + "address": "https://vault.company.com:8200", + "token": "${VAULT_TOKEN}", + "mount": "secret", + "sync_interval": "300s" + } +} +``` + +### AWS Secrets Manager + +Integrate with AWS Secrets Manager for cloud-native secret storage. + +```json +{ + "vault": { + "type": "aws_secrets_manager", + "region": "us-east-1", + "access_key_id": "${AWS_ACCESS_KEY_ID}", + "secret_access_key": "${AWS_SECRET_ACCESS_KEY}", + "sync_interval": "300s" + } +} +``` + +### Google Secret Manager + +Use Google Cloud's Secret Manager for secure key storage. + +```json +{ + "vault": { + "type": "google_secret_manager", + "project_id": "your-project-id", + "credentials_file": "/path/to/service-account.json", + "sync_interval": "300s" + } +} +``` + +### Azure Key Vault + +Connect to Azure Key Vault for Microsoft cloud environments. + +```json +{ + "vault": { + "type": "azure_key_vault", + "vault_url": "https://your-keyvault.vault.azure.net/", + "client_id": "${AZURE_CLIENT_ID}", + "client_secret": "${AZURE_CLIENT_SECRET}", + "tenant_id": "${AZURE_TENANT_ID}", + "sync_interval": "300s" + } +} +``` + +## Key Synchronization + +### Automatic Sync Process + +Bifrost automatically synchronizes keys from your vault at regular intervals: + +1. **Discovery**: Scans the configured vault paths for API keys and virtual keys +2. **Validation**: Verifies key format and accessibility +3. **Sync**: Updates Bifrost's internal key store with new and modified keys +4. **Deprecation**: Identifies and archives keys that have been removed from the vault +5. **Notification**: Logs sync status and any issues encountered + +### Sync Configuration + +Configure synchronization behavior to match your operational requirements: + +```json +{ + "vault": { + "sync_interval": "300s", + "sync_paths": [ + "bifrost/provider-keys/*", + "bifrost/virtual-keys/*" + ], + "auto_deprecate": true, + "backup_deprecated_keys": true + } +} +``` + +#### Configuration Options + +| Option | Description | Default | +|--------|-------------|---------| +| `sync_interval` | Time between sync operations | `300s` | +| `sync_paths` | Vault paths to monitor for keys | `["bifrost/*"]` | +| `auto_deprecate` | Automatically deprecate removed keys | `true` | +| `backup_deprecated_keys` | Backup keys before deprecation | `true` | + +## Key Management Lifecycle + +### Key States + +Keys in Bifrost can have the following states: + +- **Active**: Currently in use and available for requests +- **Deprecated**: Marked for removal but still functional +- **Archived**: Removed from active use but retained for audit purposes +- **Expired**: Keys that have exceeded their validity period + +### Deprecation Process + +When keys are removed from the vault: + +1. **Detection**: Next sync cycle identifies missing keys +2. **Grace Period**: Keys enter deprecated state with configurable grace period +3. **Notification**: Administrators are notified of pending deprecation +4. **Archive**: Keys are moved to archived state after grace period expires + +```json +{ + "vault": { + "deprecation": { + "grace_period": "24h", + "notify_admins": true, + "retain_archived": "90d" + } + } +} +``` + +## Security Considerations + +### Authentication + +- **Vault Tokens**: Use time-limited tokens with minimal required permissions +- **IAM Roles**: Leverage cloud provider IAM roles for secure authentication +- **Certificate-based Auth**: Support for mutual TLS authentication where available + +### Encryption + +- **Transit Encryption**: All communication with vault systems uses TLS +- **At-Rest Encryption**: Keys are encrypted in Bifrost's internal storage +- **Key Rotation**: Automatic detection and handling of rotated vault credentials + +### Audit Trail + +Complete audit logging for all vault operations: + +```json +{ + "timestamp": "2024-01-15T10:30:00Z", + "operation": "key_sync", + "vault_type": "hashicorp", + "keys_synced": 15, + "keys_deprecated": 2, + "status": "success" +} +``` diff --git a/docs/favicon.png b/docs/favicon.png new file mode 100644 index 000000000..19ed93b1f Binary files /dev/null and b/docs/favicon.png differ diff --git a/docs/features/custom-providers.mdx b/docs/features/custom-providers.mdx new file mode 100644 index 000000000..484cd2e20 --- /dev/null +++ b/docs/features/custom-providers.mdx @@ -0,0 +1,294 @@ +--- +title: "Custom Providers" +description: "Create custom provider configurations with specific request type restrictions, custom naming, and controlled access patterns." +icon: "gears" +--- + +## What Are Custom Providers? + +Custom providers allow you to create multiple instances of the same base provider, each with different configurations and access patterns. The key feature is request type control, which enables you to restrict what operations each custom provider instance can perform. + +Think of custom providers as "multiple views" of the same underlying provider β€” you can create several custom configurations for OpenAI, Anthropic, or any other provider, each optimized for different use cases while sharing the same API keys and base infrastructure. + +## Key Benefits + +- **Multiple Provider Instances**: Create several configurations of the same base provider (e.g., multiple OpenAI configurations) +- **Request Type Control**: Restrict which operations (chat, embeddings, speech, etc.) each custom provider can perform +- **Custom Naming**: Use descriptive names like "openai-production" or "openai-staging" +- **Provider Reuse**: Maximize the value of your existing provider accounts + +## How to Configure + +Custom providers are configured using the `custom_provider_config` field, which extends the standard provider configuration. The main purpose is to create multiple instances of the same base provider, each with different request type restrictions. + +**Important**: The `allowed_requests` field follows a specific behavior: +- **Omitted entirely**: All operations are allowed (default behavior) +- **Partially specified**: Only explicitly set fields are allowed, others default to `false` +- **Fully specified**: Only the operations you explicitly enable are allowed +- **Present but empty object (`{}`)**: All fields are set to false + + + + + +![Provider Configuration Interface](../media/ui-custom-provider.png) + +1. Go to **http://localhost:8080** +2. Navigate to **"Providers"** in the sidebar +3. Click **"Add New Provider"** +4. Choose a unique provider name (e.g., "openai-custom") +5. Select the base provider type (e.g., "openai") +6. Configure which request types are allowed +7. Save configuration + + + + + +```bash +# Create a chat-only custom provider +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai-custom", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "text_completion": false, + "chat_completion": true, + "chat_completion_stream": true, + "embedding": false, + "speech": false, + "speech_stream": false, + "transcription": false, + "transcription_stream": false + } + } +}' +``` + + + + + +```json +{ + "providers": { + "openai-custom": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "text_completion": false, + "chat_completion": true, + "chat_completion_stream": true, + "embedding": false, + "speech": false, + "speech_stream": false, + "transcription": false, + "transcription_stream": false + } + } + } + } +} +``` + + + + + +## Configuration Options + +### Allowed Request Types + +Control which operations your custom provider can perform. The behavior is: + +- **If `allowed_requests` is not specified**: All operations are allowed by default +- **If `allowed_requests` is specified**: Only the fields set to `true` are allowed, all others default to `false` + +Available operations: + +- **`text_completion`**: Legacy text completion requests +- **`chat_completion`**: Standard chat completion requests +- **`chat_completion_stream`**: Streaming chat responses +- **`embedding`**: Text embedding generation +- **`speech`**: Text-to-speech conversion +- **`speech_stream`**: Streaming text-to-speech +- **`transcription`**: Speech-to-text conversion +- **`transcription_stream`**: Streaming speech-to-text + +### Base Provider Types + +Custom providers can be built on these supported providers: + +- `openai` - OpenAI API +- `anthropic` - Anthropic Claude +- `bedrock` - AWS Bedrock +- `cohere` - Cohere +- `gemini` - Gemini + +## Use Cases + +### 1. Environment-Specific Configurations + +Create different configurations for production, staging, and development environments: + +```json +{ + "openai-production": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true, + "embedding": true, + "speech": true, + "speech_stream": true + } + } + }, + "openai-staging": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true, + "embedding": true, + "speech": false, + "speech_stream": false + } + } + }, + "openai-dev": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": false, + "embedding": false, + "speech": false, + "speech_stream": false + } + } + } +} +``` + +### 2. Role-Based Access Control + +Restrict capabilities based on user roles or team permissions. You can then create virtual keys for better management of who can access which providers, providing granular control over team permissions and resource usage. This integrates seamlessly with Bifrost's **[governance](./governance)** features for comprehensive access control and monitoring: + +```json +{ + "openai-developers": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true, + "embedding": true, + "text_completion": true + } + } + }, + "openai-analysts": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "embedding": true + } + } + }, + "openai-support": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": false + } + } + } +} +``` + +### 3. Feature Testing and Rollouts + +Test new features with limited user groups: + +```json +{ + "openai-beta-streaming": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true, + "embedding": false + } + } + }, + "openai-stable": { + "keys": [{ "value": "env.PROVIDER_API_KEY", "models": [], "weight": 1.0 }], + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": false, + "embedding": true + } + } + } +} +``` + +## Making Requests + +Use your custom provider name in requests: + +```bash +# Request to custom provider +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai-custom/gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Hello!"} + ] +}' +``` + +## Relationship to Provider Configuration + +Custom providers extend the standard provider configuration system. They inherit all the capabilities of their base provider while adding request type restrictions. + +**Learn more about provider configuration:** +- **[Gateway Provider Configuration](../quickstart/gateway/provider-configuration)** +- **[Go SDK Provider Configuration](../quickstart/go-sdk/provider-configuration)** + +## Next Steps + +- **[Fallbacks](./fallbacks)** - Automatic failover between providers +- **[Load Balancing](./keys-management)** - Intelligent API key management with weighted load balancing +- **[Governance](./governance)** - Advanced access control and monitoring diff --git a/docs/features/drop-in-replacement.mdx b/docs/features/drop-in-replacement.mdx new file mode 100644 index 000000000..b6be05818 --- /dev/null +++ b/docs/features/drop-in-replacement.mdx @@ -0,0 +1,78 @@ +--- +title: "Drop-in Replacement" +description: "Replace your existing AI SDK connections with Bifrost by changing just the base URL. Keep your code, gain advanced features like fallbacks, load balancing, and governance." +icon: "shuffle" +--- + +## Zero Code Changes + +The Bifrost Gateway acts as a drop-in replacement for popular AI SDKs. This means you can point your existing OpenAI, Anthropic, or Google GenAI client to Bifrost's HTTP gateway and instantly gain access to advanced features without rewriting your application. + +The magic happens with a single line change: update your `base_url` to point to Bifrost's gateway, and everything else stays exactly the same. + +## How It Works + +Bifrost provides **100% compatible endpoints** for popular AI SDKs by acting as a protocol adapter. Your existing SDK code continues to work unchanged, but now benefits from Bifrost's multi-provider support, automatic failovers, semantic caching, and governance features. + + + + + +```python +# Before: Direct to OpenAI +client = openai.OpenAI( + api_key="your-openai-key" +) + +# After: Through Bifrost +client = openai.OpenAI( + base_url="http://localhost:8080/openai", # Only change needed + api_key="dummy-key" # Keys handled by Bifrost +) +``` + + + + + +```python +# Before: Direct to Anthropic +client = anthropic.Anthropic( + api_key="your-anthropic-key" +) + +# After: Through Bifrost +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", # Only change needed + api_key="dummy-key" # Keys handled by Bifrost +) +``` + + + + + +## Instant Advanced Features + +Once your SDK points to Bifrost, you automatically get: + +- **Multi-provider support** with automatic failovers +- **Load balancing** across multiple API keys +- **Semantic caching** for faster responses +- **Governance controls** for usage monitoring and budgets +- **Request/response logging** and analytics +- **Rate limiting** and circuit breakers + +and so much more! All without changing a **single line** of your application logic. + +## Complete Integration Support + +Bifrost provides drop-in compatibility for multiple popular AI SDKs and frameworks: + +- **[OpenAI SDK](../integrations/openai-sdk)** +- **[Anthropic SDK](../integrations/anthropic-sdk)** +- **[Google GenAI SDK](../integrations/genai-sdk)** +- **[LiteLLM](../integrations/litellm-sdk)** +- **[LangChain](../integrations/langchain-sdk)** + +**For detailed setup instructions and compatibility information:** [Complete Integration Guide](../integrations/what-is-an-integration) \ No newline at end of file diff --git a/docs/features/fallbacks.mdx b/docs/features/fallbacks.mdx new file mode 100644 index 000000000..315319c07 --- /dev/null +++ b/docs/features/fallbacks.mdx @@ -0,0 +1,187 @@ +--- +title: "Fallbacks" +description: "Automatic failover between AI providers and models. When your primary provider fails, Bifrost seamlessly switches to backup providers without interrupting your application." +icon: "list-check" +--- + +## Automatic Provider Failover + +Fallbacks provide automatic failover when your primary AI provider experiences issues. Whether it's rate limiting, outages, or model unavailability, Bifrost automatically tries backup providers in the order you specify until one succeeds. + +When a fallback is triggered, Bifrost treats it as a completely new request - all configured plugins (caching, governance, logging, etc.) run again for the fallback provider, ensuring consistent behavior across all providers. + +## How Fallbacks Work + +When you configure fallbacks, Bifrost follows this process: + +1. **Primary Attempt**: Tries your main provider/model first +2. **Automatic Detection**: If the primary fails (network error, rate limit, model unavailable), Bifrost detects the failure +3. **Sequential Fallbacks**: Tries each fallback provider in order until one succeeds +4. **Success Response**: Returns the response from the first successful provider +5. **Complete Failure**: If all providers fail, returns the original error from the primary provider + +Each fallback attempt is treated as a fresh request, so all your configured plugins (semantic caching, governance rules, monitoring) apply to whichever provider ultimately handles the request. + +## Implementation Examples + + + + +```bash +# Chat completion with multiple fallbacks +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "Explain quantum computing in simple terms" + } + ], + "fallbacks": [ + "anthropic/claude-3-5-sonnet-20241022", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + ], + "max_tokens": 1000, + "temperature": 0.7 + }' +``` + +**Response (from whichever provider succeeded):** +```json +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Quantum computing is like having a super-powered calculator..." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 150, + "total_tokens": 162 + }, + "extra_fields": { + "provider": "anthropic", + "latency": 1.2 + } +} +``` + + + + + +```go +package main + +import ( + "context" + "fmt" + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/core/schemas" +) + +func chatWithFallbacks(client *bifrost.Bifrost) { + ctx := context.Background() + + // Chat request with multiple fallbacks + response, err := client.ChatCompletion(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Messages: []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), + }, + }, + }, + // Fallback chain: OpenAI β†’ Anthropic β†’ Bedrock + Fallbacks: []schemas.Fallback{ + { + Provider: schemas.Anthropic, + Model: "claude-3-5-sonnet-20241022", + }, + { + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + }, + }, + ModelParameters: &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(1000), + Temperature: bifrost.Ptr(0.7), + }, + }) + + if err != nil { + fmt.Printf("All providers failed: %v\n", err) + return + } + + // Success! Response came from whichever provider worked + fmt.Printf("Response from %s: %s\n", + response.ExtraFields.Provider, + *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) +} +``` + + + + + +## Real-World Scenarios + +**Scenario 1: Rate Limiting** +- Primary: OpenAI hits rate limit β†’ Fallback: Anthropic succeeds +- Your application continues without interruption + +**Scenario 2: Model Unavailability** +- Primary: Specific model unavailable β†’ Fallback: Different provider with similar model +- Seamless transition to equivalent capability + +**Scenario 3: Provider Outage** +- Primary: Provider experiencing downtime β†’ Fallback: Alternative provider +- Business continuity maintained + +**Scenario 4: Cost Optimization** +- Primary: Premium model for quality β†’ Fallback: Cost-effective alternative if budget exceeded +- Governance rules can trigger fallbacks based on usage + +## Fallback Behavior Details + +**What Triggers Fallbacks:** +- Network connectivity issues +- Provider API errors (500, 502, 503, 504) +- Rate limiting (429 errors) +- Model unavailability +- Request timeouts +- Authentication failures + +**What Preserves Original Error:** +- Request validation errors (malformed requests) +- Plugin-enforced blocks (governance violations) +- Certain provider-specific errors marked as non-retryable + +**Plugin Execution:** +When a fallback is triggered, the fallback request is treated as completely new: +- Semantic cache checks run again (different provider might have cached responses) +- Governance rules apply to the new provider +- Logging captures the fallback attempt +- All configured plugins execute fresh for the fallback provider + +**Plugin Fallback Control:** +Plugins can control whether fallbacks should be triggered based on their specific logic. For example: +- A custom plugin might prevent fallbacks for certain types of errors +- Security plugins might disable fallbacks for compliance reasons + +When a plugin determines that fallbacks should not be attempted, it can prevent the fallback mechanism entirely, ensuring the original error is returned immediately. + +This ensures consistent behavior regardless of which provider ultimately handles your request, while giving plugins full control over the fallback decision process. And you can always know which provider handled your request via `extra_fields`. diff --git a/docs/features/governance.mdx b/docs/features/governance.mdx new file mode 100644 index 000000000..d0c07c4b8 --- /dev/null +++ b/docs/features/governance.mdx @@ -0,0 +1,763 @@ +--- +title: "Budget Management" +description: "Enterprise-grade budget management and cost control with hierarchical budget allocation through virtual keys, teams, and customers." +icon: "money-bills" +--- + +## Overview + +Bifrost's budget management system provides comprehensive cost control and financial governance for enterprise AI deployments. It operates through a **hierarchical budget structure** that enables granular cost management, usage tracking, and financial oversight across your entire organization. + +**Core Hierarchy:** +``` +Customer (has independent budget) + ↓ (one-to-many) +Team (has independent budget) + ↓ (one-to-many) +Virtual Key (has independent budget + rate limits) + +OR + +Customer (has independent budget) + ↓ (direct attachment) +Virtual Key (has independent budget + rate limits) + +OR + +Virtual Key (standalone - has independent budget + rate limits) +``` + +**Key Capabilities:** +- **Virtual Keys** - Primary access control via `x-bf-vk` header (exclusive team OR customer attachment) +- **Budget Management** - Independent budget limits at each hierarchy level with cumulative checking +- **Rate Limiting** - Request and token-based throttling (VK-level only) +- **Model/Provider Filtering** - Granular access control per virtual key +- **Usage Tracking** - Real-time monitoring and audit trails +- **Audit Headers** - Optional team and customer identification + +**Budgeting Modes:** +- **Mandatory** (`enforce_governance_header: true`) - All requests require `x-bf-vk` header +- **Optional** (`enforce_governance_header: false`) - Governance applied only when `x-bf-vk` header present + +For detailed implementation architecture, see [Architecture > Plugins > Governance](../architecture/plugins/governance). + +--- + +## Virtual Keys + +Virtual Keys are the primary governance entity in Bifrost. Users and applications authenticate using the `x-bf-vk` header, which maps to specific access permissions, budgets, and rate limits. + +**Key Features:** +- **Access Control** - Model and provider filtering +- **Cost Management** - Independent budgets (checked along with team/customer budgets if attached) +- **Rate Limiting** - Token and request-based throttling (VK-level only) +- **Key Restrictions** - Limit VK to specific provider API keys (if configured, VK can only use those keys) +- **Exclusive Attachment** - Belongs to either one team OR one customer OR neither (mutually exclusive) +- **Active/Inactive Status** - Enable/disable access instantly + +### Configuration + + + + +1. **Navigate to Virtual Keys** + - Open Bifrost UI at `http://localhost:8080` + - Go to **Governance** β†’ **Virtual Keys** + +2. **Create Virtual Key** + +![Virtual Key Creation](../../media/ui-virtual-key.png) + +**Required Fields:** +- **Name**: Descriptive identifier +- **Description**: Optional usage details + +**Access Control:** +- **Allowed Models**: Specific models (empty = all allowed) +- **Allowed Providers**: Specific providers (empty = all allowed) + +**Budget Settings:** +- **Max Limit**: Dollar amount (e.g., `10.50`) +- **Reset Duration**: `1m`, `1h`, `1d`, `1w`, `1M` + +**Rate Limits:** +- **Token Limit**: Max tokens per period +- **Request Limit**: Max requests per period +- **Reset Duration**: Reset frequency for each limit + +**Associations:** +- **Team**: Assign to existing team (mutually exclusive with customer) +- **Customer**: Assign to existing customer (mutually exclusive with team) +- **Provider Keys**: Restrict VK to specific API keys (optional - leave empty for no restrictions) + +3. **Save Configuration** + - Click **Create Virtual Key** + - Note the generated VK value for client use + + + + +**Create Virtual Key (attached to team):** +```bash +curl -X POST http://localhost:8080/api/governance/virtual-keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Engineering Team API", + "description": "Main API key for engineering team", + "allowed_models": ["gpt-4o-mini", "claude-3-sonnet-20240229"], + "allowed_providers": ["openai", "anthropic"], + "team_id": "team-eng-001", + "budget": { + "max_limit": 100.00, + "reset_duration": "1M" + }, + "rate_limit": { + "token_max_limit": 10000, + "token_reset_duration": "1h", + "request_max_limit": 100, + "request_reset_duration": "1m" + }, + "key_ids": ["8c52039e-38c6-48b2-8016-0bd884b7befb"], + "is_active": true + }' +``` + +**Create Virtual Key (directly attached to customer):** +```bash +curl -X POST http://localhost:8080/api/governance/virtual-keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Executive API Key", + "description": "Direct customer-level API access", + "allowed_models": ["gpt-4o", "claude-3-opus-20240229"], + "allowed_providers": ["openai", "anthropic"], + "customer_id": "customer-acme-corp", + "budget": { + "max_limit": 500.00, + "reset_duration": "1M" + }, + "is_active": true + }' +``` + +> **Note**: +> - `team_id` and `customer_id` are mutually exclusive - a VK can only belong to one team OR one customer, not both. +> - `key_ids` restricts the VK to only use those specific provider API keys. Omit this field to allow access to all available keys. + +**Update Virtual Key:** +```bash +curl -X PUT http://localhost:8080/api/governance/virtual-keys/{vk_id} \ + -H "Content-Type: application/json" \ + -d '{ + "description": "Updated description", + "allowed_models": ["gpt-4o", "claude-3-opus-20240229"], + "budget": { + "max_limit": 150.00, + "reset_duration": "1M" + } + }' +``` + +**Get Virtual Keys:** +```bash +# List all virtual keys +curl http://localhost:8080/api/governance/virtual-keys + +# Get specific virtual key +curl http://localhost:8080/api/governance/virtual-keys/{vk_id} +``` + +**Delete Virtual Key:** +```bash +curl -X DELETE http://localhost:8080/api/governance/virtual-keys/{vk_id} +``` + + + + +```json +{ + "client": { + "enable_governance": true, + "enforce_governance_header": true + }, + "governance": { + "virtual_keys": [ + { + "id": "vk-001", + "name": "Engineering Team API", + "value": "vk-engineering-main", + "description": "Main API key for engineering team", + "is_active": true, + "allowed_models": ["gpt-4o-mini", "claude-3-sonnet-20240229"], + "allowed_providers": ["openai", "anthropic"], + "team_id": "team-eng-001", + "budget_id": "budget-eng-vk", + "rate_limit_id": "rate-limit-eng-vk", + "keys": [ + {"key_id": "8c52039e-38c6-48b2-8016-0bd884b7befb"} + ] + }, + { + "id": "vk-002", + "name": "Executive API Key", + "value": "vk-executive-direct", + "description": "Direct customer-level API access", + "is_active": true, + "allowed_models": ["gpt-4o", "claude-3-opus-20240229"], + "allowed_providers": ["openai", "anthropic"], + "customer_id": "customer-acme-corp", + "budget_id": "budget-exec-vk", + "keys": [ + {"key_id": "8c52039e-38c6-48b2-8016-0bd884b7befb"} + ] + } + ], + "budgets": [ + { + "id": "budget-eng-vk", + "max_limit": 100.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + }, + { + "id": "budget-exec-vk", + "max_limit": 500.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + } + ], + "rate_limits": [ + { + "id": "rate-limit-eng-vk", + "token_max_limit": 10000, + "token_reset_duration": "1h", + "token_current_usage": 0, + "token_last_reset": "2025-01-01T00:00:00Z", + "request_max_limit": 100, + "request_reset_duration": "1m", + "request_current_usage": 0, + "request_last_reset": "2025-01-01T00:00:00Z" + } + ] + } +} +``` + + + + +### Key Restrictions + +Virtual Keys can be restricted to use only specific provider API keys. When key restrictions are configured, the VK can only access those designated keys, providing fine-grained control over which API keys different users or applications can utilize. + +**How It Works:** +- **No Restrictions** (default): VK can use any available provider keys based on load balancing +- **With Restrictions**: VK limited to only the specified key IDs, regardless of other available keys + +**Example Scenario:** +``` +Available Provider Keys: +β”œβ”€β”€ key-prod-001 β†’ sk-prod-key... (Production OpenAI key) +β”œβ”€β”€ key-dev-002 β†’ sk-dev-key... (Development OpenAI key) +└── key-test-003 β†’ sk-test-key... (Testing OpenAI key) + +Virtual Key Restrictions: +β”œβ”€β”€ vk-prod-main +β”‚ β”œβ”€β”€ Allowed Models: [gpt-4o] +β”‚ └── Restricted Keys: [key-prod-001] ← ONLY production key +β”œβ”€β”€ vk-dev-main +β”‚ β”œβ”€β”€ Allowed Models: [gpt-4o-mini] +β”‚ └── Restricted Keys: [key-dev-002, key-test-003] ← Dev + test keys +└── vk-unrestricted + β”œβ”€β”€ Allowed Models: [all models] + └── Restricted Keys: [] ← Can use ANY available key +``` + +**Request Behavior:** +```bash +# Production VK - will ONLY use key-prod-001 +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-vk: vk-prod-main" \ + -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello!"}]}' + +# Development VK - will load balance between key-dev-002 and key-test-003 +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-vk: vk-dev-main" \ + -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}]}' + +# VK with no key restrictions - can use any available OpenAI key +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-vk: vk-unrestricted" \ + -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + +**Use Cases:** +- **Environment Separation** - Production VKs use production keys, dev VKs use dev keys +- **Cost Control** - Different teams use keys with different billing accounts +- **Access Control** - Restrict sensitive keys to specific VKs only +- **Compliance** - Ensure certain workloads only use compliant/audited keys + +--- + +## Teams + +Teams provide organizational grouping for virtual keys with department-level budget management. Teams can belong to one customer and have their own independent budget allocation. + +**Key Features:** +- **Organizational Structure** - Group multiple virtual keys +- **Independent Budgets** - Department-level cost control (separate from customer budgets) +- **Customer Association** - Can belong to one customer (optional) +- **No Rate Limits** - Teams cannot have rate limits (VK-level only) + +**Configuration** + + + + +1. **Navigate to Teams** + - Open Bifrost UI at `http://localhost:8080` + - Go to **Governance** β†’ **Teams** + +2. **Create Team** + +![Team Creation](../../media/ui-create-teams.png) + +**Required Fields:** +- **Name**: Team identifier +- **Customer**: Optional parent customer + +**Budget Settings:** +- **Max Limit**: Department budget in dollars +- **Reset Duration**: Budget reset frequency + +**Virtual Key Assignment:** +- Assign existing virtual keys to team +- Create new virtual keys under team + +3. **Save Configuration** + - Click **Create Team** + - Assign virtual keys to the team + + + + +**Create Team:** +```bash +curl -X POST http://localhost:8080/api/governance/teams \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Engineering Team", + "customer_id": "customer-acme-corp", + "budget": { + "max_limit": 500.00, + "reset_duration": "1M" + } + }' +``` + +**Update Team:** +```bash +curl -X PUT http://localhost:8080/api/governance/teams/{team_id} \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Updated Engineering Team", + "budget": { + "max_limit": 750.00, + "reset_duration": "1M" + } + }' +``` + +**Get Teams:** +```bash +# List all teams +curl http://localhost:8080/api/governance/teams + +# Get specific team +curl http://localhost:8080/api/governance/teams/{team_id} +``` + +**Delete Team:** +```bash +curl -X DELETE http://localhost:8080/api/governance/teams/{team_id} +``` + + + + +```json +{ + "governance": { + "teams": [ + { + "id": "team-eng-001", + "name": "Engineering Team", + "customer_id": "customer-acme-corp", + "budget_id": "budget-team-eng" + }, + { + "id": "team-sales-001", + "name": "Sales Team", + "customer_id": "customer-acme-corp", + "budget_id": "budget-team-sales" + } + ], + "budgets": [ + { + "id": "budget-team-eng", + "max_limit": 500.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + }, + { + "id": "budget-team-sales", + "max_limit": 250.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + } + ] + } +} +``` + + + + +--- + +## Customers + +Customers represent the highest level in the governance hierarchy, typically corresponding to organizations or major business units. They provide top-level budget control and organizational structure. + +**Key Features:** +- **Top-Level Organization** - Highest hierarchy level +- **Independent Budgets** - Organization-wide cost control (separate from team/VK budgets) +- **Team Management** - Contains multiple teams and direct VKs +- **No Rate Limits** - Customers cannot have rate limits (VK-level only) + +**Configuration** + + + + +1. **Navigate to Customers** + - Open Bifrost UI at `http://localhost:8080` + - Go to **Governance** β†’ **Customers** + +2. **Create Customer** + +![Customer Creation](../../media/ui-create-customer.png) + +**Required Fields:** +- **Name**: Organization identifier + +**Budget Settings:** +- **Max Limit**: Organization budget in dollars +- **Reset Duration**: Budget reset frequency + +**Team Management:** +- View all teams under customer +- Create new teams under customer +- Monitor aggregate usage + +3. **Save Configuration** + - Click **Create Customer** + - Create teams under the customer + + + + +**Create Customer:** +```bash +curl -X POST http://localhost:8080/api/governance/customers \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Acme Corporation", + "budget": { + "max_limit": 2000.00, + "reset_duration": "1M" + } + }' +``` + +**Update Customer:** +```bash +curl -X PUT http://localhost:8080/api/governance/customers/{customer_id} \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Acme Corp (Updated)", + "budget": { + "max_limit": 2500.00, + "reset_duration": "1M" + } + }' +``` + +**Get Customers:** +```bash +# List all customers +curl http://localhost:8080/api/governance/customers + +# Get specific customer +curl http://localhost:8080/api/governance/customers/{customer_id} +``` + +**Delete Customer:** +```bash +curl -X DELETE http://localhost:8080/api/governance/customers/{customer_id} +``` + + + + +```json +{ + "governance": { + "customers": [ + { + "id": "customer-acme-corp", + "name": "Acme Corporation", + "budget_id": "budget-customer-acme" + }, + { + "id": "customer-beta-inc", + "name": "Beta Inc", + "budget_id": "budget-customer-beta" + } + ], + "budgets": [ + { + "id": "budget-customer-acme", + "max_limit": 2000.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + }, + { + "id": "budget-customer-beta", + "max_limit": 1500.00, + "reset_duration": "1M", + "current_usage": 0.0, + "last_reset": "2025-01-01T00:00:00Z" + } + ] + } +} +``` + + + + +--- + +## Usage & Headers + +### Required Header + +All governance-enabled requests must include the virtual key header: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-bf-vk: vk-engineering-main" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +### Optional Audit Headers + +Include additional headers for enhanced tracking and audit trails: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-bf-vk: vk-engineering-main" \ + -H "x-bf-team: team-eng-001" \ + -H "x-bf-customer: customer-acme-corp" \ + -H "x-bf-user-id: user-alice" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +**Header Definitions:** +- `x-bf-vk` - **Required** virtual key for access control +- `x-bf-team` - Optional team identifier for audit trails +- `x-bf-customer` - Optional customer identifier for audit trails +- `x-bf-user-id` - Optional user identifier for detailed tracking + +### Cost Calculation + +Bifrost automatically calculates costs based on: +- **Provider Pricing** - Real-time model pricing data +- **Token Usage** - Input + output tokens from API responses +- **Request Type** - Different pricing for chat, text, embedding, speech, transcription +- **Cache Status** - Reduced costs for cached responses +- **Batch Operations** - Volume discounts for batch requests + +Cost calculation details are covered in [Architecture > Plugins > Governance](../architecture/plugins/governance). + +### Budget Checking Flow + +When a request is made with a virtual key, Bifrost checks **all applicable budgets independently** in the hierarchy. Each budget must have sufficient remaining balance for the request to proceed. + +**Checking Sequence:** + +**For VK β†’ Team β†’ Customer:** +``` +1. βœ“ VK Budget (if VK has budget) +2. βœ“ Team Budget (if VK's team has budget) +3. βœ“ Customer Budget (if team's customer has budget) +``` + +**For VK β†’ Customer (direct):** +``` +1. βœ“ VK Budget (if VK has budget) +2. βœ“ Customer Budget (if VK's customer has budget) +``` + +**For Standalone VK:** +``` +1. βœ“ VK Budget (if VK has budget) +``` + +**Important Notes:** +- **All applicable budgets must pass** - any single budget failure blocks the request +- **Budgets are independent** - each tracks its own usage and limits +- **Costs are deducted from all applicable budgets** - same cost applied to each level +- **Rate limits checked only at VK level** - teams and customers have no rate limits + +**Example:** +- VK budget: $9/$10 remaining βœ“ +- Team budget: $15/$20 remaining βœ“ +- Customer budget: $45/$50 remaining βœ“ +- **Result: Allowed** (no budget is exceeded) +- After request: + - Request cost: $2 + - Updated VK=$11/$10, Team=$17/$20, Customer=$47/$50 + - Then the next request will be blocked. + +--- + +## Error Responses + +- Virtual Key Not Found (400) +```json +{ + "error": { + "type": "virtual_key_required", + "message": "x-bf-vk header is missing" + } +} +``` + +- Virtual Key Blocked (403) +```json +{ + "error": { + "type": "virtual_key_blocked", + "message": "Virtual key is inactive" + } +} +``` + +- Model Not Allowed (403) +```json +{ + "error": { + "type": "model_blocked", + "message": "Model 'gpt-4o' is not allowed for this virtual key" + } +} +``` + +- Provider Not Allowed (403) +```json +{ + "error": { + "type": "provider_blocked", + "message": "Provider 'anthropic' is not allowed for this virtual key" + } +} +``` + +- Rate Limit Exceeded (429) +```json +{ + "error": { + "type": "rate_limited", + "message": "Rate limits exceeded: [token limit exceeded (1500/1000, resets every 1h)]" + } +} +``` + +- Token Limit Exceeded (429) +```json +{ + "error": { + "type": "token_limited", + "message": "Rate limits exceeded: [token limit exceeded (1500/1000, resets every 1h)]" + } +} +``` + +- Request Limit Exceeded (429) +```json +{ + "error": { + "type": "request_limited", + "message": "Rate limits exceeded: [request limit exceeded (101/100, resets every 1m)]" + } +} +``` + +- Budget Exceeded (402) +```json +{ + "error": { + "type": "budget_exceeded", + "message": "Budget check failed: VK budget exceeded: 105.50 > 100.00 dollars" + } +} +``` + +**Budget Error Variations:** +- `"VK budget exceeded: 105.50 > 100.00 dollars"` - Virtual Key budget exceeded +- `"Team budget exceeded: 250.75 > 250.00 dollars"` - Team budget exceeded +- `"Customer budget exceeded: 1500.25 > 1500.00 dollars"` - Customer budget exceeded + +--- + +## Reset Durations + +Budgets and rate limits support flexible reset durations: + +**Format Examples:** +- `1m` - 1 minute +- `5m` - 5 minutes +- `1h` - 1 hour +- `1d` - 1 day +- `1w` - 1 week +- `1M` - 1 month + +**Common Patterns:** +- **Rate Limits**: `1m`, `1h`, `1d` for request throttling +- **Budgets**: `1d`, `1w`, `1M` for cost control +- **Development**: `5m`, `15m` for testing scenarios + +--- + +## Next Steps + +- **[Architecture Overview](../architecture/plugins/governance)** - Technical implementation details +- **[Telemetry](./telemetry)** - Monitoring governance usage and performance +- **[Tracing](./tracing)** - Audit trails and request tracking +- **[Provider Configuration](../quickstart/gateway/provider-configuration)** - Setting up provider API keys diff --git a/docs/features/keys-management.mdx b/docs/features/keys-management.mdx new file mode 100644 index 000000000..d05c8e800 --- /dev/null +++ b/docs/features/keys-management.mdx @@ -0,0 +1,251 @@ +--- +title: "Load Balance" +description: "Intelligent API key management with weighted load balancing, model-specific filtering, and automatic failover. Distribute traffic across multiple keys for optimal performance and reliability." +icon: "scale-balanced" +--- + +## Smart Key Distribution + +Bifrost's key management system goes beyond simple API key storage. It provides intelligent load balancing, model-specific key filtering, and weighted distribution to optimize performance and manage costs across multiple API keys. + +When you configure multiple keys for a provider, Bifrost automatically distributes requests using sophisticated selection algorithms that consider key weights, model compatibility, and deployment mappings. + +## How Key Selection Works + +Bifrost follows a precise selection process for every request: + +1. **Context Override Check**: First checks if a key is explicitly provided in context (bypassing management) +2. **Provider Key Lookup**: Retrieves all configured keys for the requested provider +3. **Model Filtering**: Filters keys that support the requested model +4. **Deployment Validation**: For Azure/Bedrock, validates deployment mappings +5. **Weighted Selection**: Uses weighted random selection among eligible keys + +This ensures optimal key usage while respecting your configuration constraints. + +## Implementation Examples + + + + + +```bash +# Configure multiple keys with weights via API +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY_1", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 0.7 + }, + { + "value": "env.OPENAI_API_KEY_2", + "models": [], + "weight": 0.3 + } + ] + }' + +# Regular request (uses weighted key selection) +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' + +# Request with direct API key (bypasses key management) +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-your-direct-api-key" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + + + + + +```go +package main + +import ( + "context" + "github.com/maximhq/bifrost/core/schemas" +) + +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.OpenAI: + return []schemas.Key{ + { + ID: "primary-key", + Value: "env.OPENAI_API_KEY_1", + Models: ["gpt-4o", "gpt-4o-mini"], // Model whitelist + Weight: 0.7, // 70% of traffic + }, + { + ID: "secondary-key", + Value: "env.OPENAI_API_KEY_2", + Models: [], // Empty = supports all models + Weight: 0.3, // 30% of traffic + }, + }, nil + case schemas.Anthropic: + return []schemas.Key{ + { + Value: "env.ANTHROPIC_API_KEY", + Models: ["claude-3-5-sonnet-20241022"], + Weight: 1.0, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} + +// Using with explicit context key (bypasses key management) +func makeRequestWithDirectKey() { + ctx := context.Background() + + // Direct key bypasses all key management + directKey := schemas.Key{ + Value: "sk-direct-api-key", + Weight: 1.0, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyDirectKey, directKey) + + response, err := client.ChatCompletion(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Messages: messages, + }) +} +``` + + + + + +## Weighted Load Balancing + +Bifrost uses weighted random selection to distribute requests across multiple keys. This allows you to: + +**Control Traffic Distribution:** +- Assign higher weights to premium keys with better rate limits +- Balance between production and backup keys +- Gradually migrate traffic during key rotation + +**Weight Calculation Example:** +``` +Key 1: Weight 0.7 (70% probability) +Key 2: Weight 0.3 (30% probability) +Total Weight: 1.0 + +Random selection ensures statistical distribution over time +``` + +**Algorithm Details:** +1. Calculate total weight of all eligible keys +2. Generate random number between 0 and total weight +3. Select key based on cumulative weight ranges +4. If selected key fails, automatic fallback to next available key + +## Model Whitelisting and Filtering + +Keys can be restricted to specific models for access control and cost management: + +**Model Filtering Logic:** +- **Empty `models` array**: Key supports ALL models for that provider +- **Populated `models` array**: Key only supports listed models +- **Model mismatch**: Key is excluded from selection for that request + +**Use Cases:** +- **Premium Models**: Dedicated keys for expensive models (GPT-4, Claude-3) +- **Team Separation**: Different keys for different teams or projects +- **Cost Control**: Restrict access to specific model tiers +- **Compliance**: Separate keys for different security requirements + +**Example Model Restrictions:** +```json +{ + "keys": [ + { + "value": "premium-key", + "models": ["gpt-4o", "o1-preview"], // Only premium models + "weight": 1.0 + }, + { + "value": "standard-key", + "models": ["gpt-4o-mini", "gpt-3.5-turbo"], // Only standard models + "weight": 1.0 + } + ] +} +``` + +## Deployment Mapping (Azure & Bedrock) + +For cloud providers with deployment-based routing, Bifrost validates deployment availability: + +**Azure OpenAI:** +- Keys must have deployment mappings for specific models +- Deployment name maps to actual Azure deployment identifier +- Missing deployment excludes key from selection + +**AWS Bedrock:** +- Supports model profiles and direct model access +- Deployment mappings enable inference profile routing +- ARN configuration determines URL formation + +**Deployment Validation Process:** +1. Check if provider uses deployments (Azure/Bedrock) +2. Verify deployment exists for requested model +3. Exclude keys without proper deployment mapping +4. Continue with standard weighted selection + +## Direct Key Bypass + +For scenarios requiring explicit key control, Bifrost supports bypassing the entire key management system: + +**Go SDK Context Override:** +Pass a key directly in the request context using `schemas.BifrostContextKeyDirectKey`. This completely bypasses provider key lookup and selection. + +**Gateway Header-based Keys:** +Send API keys in `Authorization` (Bearer) or `x-api-key` headers. Requires `allow_direct_keys` setting to be enabled. + +**Enable Direct Keys:** + + + + + +![Web UI](../../media/ui-config-direct-keys.png) + +1. Navigate to **Configuration** page +2. Toggle **"Allow Direct Keys"** to enabled +3. Save configuration + + + + +```json +{ + "client": { + "allow_direct_keys": true + } +} +``` + + + + + +**When to Use Direct Keys:** +- Per-user API key scenarios +- External key management systems +- Testing with specific keys +- Debugging key-related issues diff --git a/docs/features/mcp.mdx b/docs/features/mcp.mdx new file mode 100644 index 000000000..fc2c744e3 --- /dev/null +++ b/docs/features/mcp.mdx @@ -0,0 +1,806 @@ +--- +title: "Model Context Protocol (MCP)" +description: "Enable AI models to discover and execute external tools dynamically. Transform static chat models into action-capable agents with filesystem access, web search, databases, and custom business logic." +icon: "toolbox" +--- + +## Overview + +**Model Context Protocol (MCP)** enables AI models to seamlessly discover and execute external tools at runtime, transforming static chat models into dynamic, action-capable agents. Instead of being limited to text generation, AI models can interact with filesystems, search the web, query databases, and execute custom business logic through external MCP servers. + +Bifrost's MCP integration provides a secure, high-performance bridge between AI models and external tools, with client-side control over all tool execution and granular filtering capabilities. + +**πŸ”’ Security-First Design**: Bifrost never automatically executes tool calls. Instead, it provides APIs for explicit tool execution, ensuring human oversight and approval for all potentially dangerous operations. + +### Key Benefits + +| Feature | Description | +|---------|-------------| +| **Dynamic Discovery** | Tools are discovered at runtime from external MCP servers | +| **Stateless Design** | Independent API calls with no session state management | +| **Client-Side Control** | Bifrost manages all tool execution for security and observability | +| **Multiple Protocols** | STDIO, HTTP, and SSE connection types | +| **Granular Filtering** | Control tool availability per request and client | +| **High Performance** | Async execution with minimal latency overhead | +| **Copy-Pastable Responses** | Tool results designed for seamless conversation assembly | + +--- + +## How MCP Works in Bifrost + +Bifrost acts as an MCP client that connects to external MCP servers hosting tools. The integration is **completely stateless** with independent API calls: + +1. **Discovery**: Bifrost connects to configured MCP servers and discovers available tools +2. **Integration**: Tools are automatically added to the AI model's function calling schema +3. **Suggestion**: Chat completion requests return tool call suggestions (not executed) +4. **Execution**: Separate tool execution API calls execute specific tool calls +5. **Assembly**: Your application manages conversation state and assembles chat history +6. **Continuation**: Follow-up chat requests use the complete conversation history + +**Stateless Tool Flow:** +``` +Chat Request β†’ Tool Call Suggestions (Independent) + ↓ +Tool Execution Request β†’ Tool Results (Independent) + ↓ +Your App Assembles History β†’ Continue Chat (Independent) +``` + +**Bifrost never automatically executes tool calls.** All API calls are independent and stateless: + +- **Chat completions** return tool call suggestions without executing them +- **Tool execution** requires separate API calls with explicit tool call data +- **No state management** - your application controls conversation flow +- **Copy-pastable responses** designed for easy conversation assembly + +This design prevents: +- Unintended API calls to external services +- Accidental data modification or deletion +- Execution of potentially harmful commands + +**Implementation Pattern:** +``` +1. POST /v1/chat/completions β†’ Get tool call suggestions (stateless) +2. Your App Reviews Tool Calls β†’ Decides which to execute +3. POST /v1/mcp/tool/execute β†’ Execute specific tool calls (stateless) +4. Your App Assembles History β†’ Continue with complete conversation +``` + +This stateless pattern ensures **explicit control** over all tool operations while providing responses optimized for conversation continuity. + +--- + +## Setup Guides + +### Go SDK Setup + +Configure MCP in your Bifrost initialization: + +```go +package main + +import ( + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +func main() { + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "node", + Args: []string{"filesystem-mcp-server.js"}, + }, + }, + { + Name: "web-search", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: bifrost.Ptr("http://localhost:3001/mcp"), + }, + }, + } + + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: account, + MCPConfig: mcpConfig, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + }) +} +``` + +### Gateway Setup + + + + +![MCP Configuration in Web UI](../media/ui-mcp-config.png) + +1. Navigate to **MCP Clients** in the Bifrost Gateway UI +2. Click **Add MCP Client** +3. Configure connection details: + - **Name**: Unique identifier for the MCP client + - **Connection Type**: STDIO, HTTP, or SSE + - **Connection Details**: Command/URL based on connection type + - **Tool Filtering**: Optional whitelist/blacklist of tools + +The UI automatically validates configurations and shows connection status in real-time. + + + + +Add MCP clients via the Gateway API: + +```bash +# Add STDIO MCP client +curl -X POST http://localhost:8080/api/mcp/client \ + -H "Content-Type: application/json" \ + -d '{ + "name": "filesystem-tools", + "connection_type": "stdio", + "stdio_config": { + "command": "node", + "args": ["filesystem-mcp-server.js"], + "envs": ["NODE_ENV"] + }, + "tools_to_execute": ["read_file", "write_file"] + }' + +# Add HTTP MCP client +curl -X POST http://localhost:8080/api/mcp/client \ + -H "Content-Type: application/json" \ + -d '{ + "name": "web-search", + "connection_type": "http", + "connection_string": "http://localhost:3001/mcp" + }' +``` + + + + +Configure MCP clients in your `config.json`: + +```json +{ + "mcp": { + "client_configs": [ + { + "name": "filesystem-tools", + "connection_type": "stdio", + "stdio_config": { + "command": "node", + "args": ["filesystem-mcp-server.js"], + "envs": ["NODE_ENV"] + }, + "tools_to_execute": ["read_file", "write_file", "list_directory"] + }, + { + "name": "web-search", + "connection_type": "http", + "connection_string": "env.WEB_SEARCH_MCP_URL", + "tools_to_skip": ["internal_debug_tool"] + }, + { + "name": "real-time-data", + "connection_type": "sse", + "connection_string": "https://api.example.com/mcp/sse" + } + ] + } +} +``` + + + + +--- + +## Connection Types + +### STDIO Connection + +STDIO connections launch external processes and communicate via standard input/output. Best for local tools and scripts. + +**Configuration:** +```json +{ + "name": "local-tools", + "connection_type": "stdio", + "stdio_config": { + "command": "python", + "args": ["-m", "my_mcp_server"], + "envs": ["PYTHON_PATH", "API_KEY"] + } +} +``` + +**Use Cases:** +- Local filesystem operations +- Database queries with local credentials +- Python/Node.js MCP servers +- Custom business logic scripts + +### HTTP Connection + +HTTP connections communicate with MCP servers via HTTP requests. Ideal for remote services and microservices. + +**Configuration:** +```json +{ + "name": "remote-api", + "connection_type": "http", + "connection_string": "https://mcp-server.example.com/api" +} +``` + +**Use Cases:** +- Remote API integrations +- Cloud-hosted MCP services +- Microservice architectures +- Third-party tool providers + +### SSE Connection + +Server-Sent Events (SSE) connections provide real-time, persistent connections to MCP servers. Best for streaming data and live updates. + +**Configuration:** +```json +{ + "name": "live-data", + "connection_type": "sse", + "connection_string": "https://stream.example.com/mcp/events" +} +``` + +**Use Cases:** +- Real-time market data +- Live system monitoring +- Streaming analytics +- Event-driven workflows + +--- + +## End-to-End Tool Calling + + + + +Complete tool calling workflow with the Go SDK: + +```go +package main + +import ( + "context" + "fmt" + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +func main() { + // Initialize Bifrost with MCP + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: account, + MCPConfig: &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "node", + Args: []string{"fs-mcp-server.js"}, + }, + }, + }, + }, + }) + + firstMessage := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Read the contents of config.json file"), + }, + } + + // Create request with tools automatically included + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + firstMessage, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + }, + } + + // Send chat completion request - MCP tools are automatically available + response, err := client.ChatCompletionRequest(context.Background(), request) + if err != nil { + panic(err) + } + + // Build conversation history for final response + conversationHistory := []schemas.BifrostMessage{ + firstMessage, + } + + // Handle tool calls in response (suggestions only - not executed) + if response.Choices[0].Message.ToolCalls != nil { + secondMessage := response.Choices[0].Message + + // Add assistant message with tool calls to history + conversationHistory = append(conversationHistory, secondMessage) + + for _, toolCall := range *secondMessage.ToolCalls { + fmt.Printf("Tool suggested: %s\n", *toolCall.Function.Name) + + // YOUR APPLICATION DECISION: Review the tool call + // - Validate tool name and arguments + // - Apply security and business rules + // - Check permissions and rate limits + // - Decide whether to execute + + shouldExecute := validateToolCall(toolCall) // Your validation logic + if !shouldExecute { + fmt.Printf("Tool call rejected by application\n") + continue + } + + // EXPLICIT EXECUTION: Separate API call + thirdMessage, err := client.ExecuteMCPTool(context.Background(), toolCall) + if err != nil { + fmt.Printf("Tool execution failed: %v\n", err) + continue + } + + fmt.Printf("Tool result: %s\n", *thirdMessage.Content.ContentStr) + + // Add tool result to conversation history + conversationHistory = append(conversationHistory, thirdMessage) + } + + // Send complete conversation history for final response + finalRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationHistory, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + }, + } + + finalResponse, err := client.ChatCompletionRequest(context.Background(), finalRequest) + if err != nil { + panic(err) + } + + fmt.Printf("Final response: %s\n", *finalResponse.Choices[0].Message.Content.ContentStr) + } +} +``` + + + + +Complete tool calling workflow via Gateway API: + +```bash +# 1. Send chat completion request - tools are automatically included +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": "Show me latest videos regarding Bifrost" + } + ] + }' + +# Response includes tool calls (suggestions only - NOT executed yet): +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_f5aAgjJAC9FO4Or0F2oCVAho", + "type": "function", + "function": { + "name": "YOUTUBE_SEARCH_YOU_TUBE", + "arguments": "{\"q\":\"Bifrost\",\"part\":\"snippet\",\"maxResults\":5}" + } + }] + } + }] +} + +# 2. YOUR APPLICATION DECISION: Review the tool call +# - Validate the search query is appropriate +# - Check rate limits and quotas +# - Apply content filtering rules +# - Approve or reject based on business logic + +# 3. EXPLICIT EXECUTION: Execute the approved tool call (request body is the same as the tool call suggestion) +curl -X POST http://localhost:8080/v1/mcp/tool/execute \ + -H "Content-Type: application/json" \ + -d '{ + "type": "function", + "id": "call_f5aAgjJAC9FO4Or0F2oCVAho", + "function": { + "name": "YOUTUBE_SEARCH_YOU_TUBE", + "arguments": "{\"q\":\"Bifrost\",\"part\":\"snippet\",\"maxResults\":5}" + } + }' + +# Tool execution response (copy-pastable for conversation): +{ + "role": "tool", + "content": "{\n\"data\": {\n\"response_data\": {\n\"items\": [\n{\n\"snippet\": {\n\"title\": \"Fastest LLM Gateway - Bifrost\",\n \"description\": \"Bifrost is the fastest LLM Gateway that allows you to use any LLM...\"\n}\n}\n]\n}\n}\n}", + "tool_call_id": "call_f5aAgjJAC9FO4Or0F2oCVAho" +} + +# 4. Assemble complete conversation history and continue +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": "Show me latest videos regarding Bifrost" + }, + { + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_f5aAgjJAC9FO4Or0F2oCVAho", + "type": "function", + "function": { + "name": "YOUTUBE_SEARCH_YOU_TUBE", + "arguments": "{\"q\":\"Bifrost\",\"part\":\"snippet\",\"maxResults\":5}" + } + }] + }, + { + "role": "tool", + "content": "{\n\"data\": {\n\"response_data\": {...}\n }\n}", + "tool_call_id": "call_f5aAgjJAC9FO4Or0F2oCVAho" + } + ] + }' + +# Final response with formatted results: +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "Here are the latest videos related to \"Bifrost\":\n\n1. **Fastest LLM Gateway - Bifrost**\n - Published: August 21, 2025\n - Description: Bifrost is the fastest LLM Gateway that allows you to use any LLM..." + } + }] +} +``` + + + + +--- + +## Tool Registry (Go SDK Only) + +The Go SDK provides a powerful tool registry for hosting custom tools directly within your application using typed handlers. + +### Registering Typed Tools + +```go +package main + +import ( + "fmt" + "strings" + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// Define typed arguments for your tool +type CalculatorArgs struct { + Operation string `json:"operation"` // add, subtract, multiply, divide + A float64 `json:"a"` + B float64 `json:"b"` +} + +// Define typed tool handler +func calculatorHandler(args CalculatorArgs) (string, error) { + switch strings.ToLower(args.Operation) { + case "add": + return fmt.Sprintf("%.2f", args.A + args.B), nil + case "subtract": + return fmt.Sprintf("%.2f", args.A - args.B), nil + case "multiply": + return fmt.Sprintf("%.2f", args.A * args.B), nil + case "divide": + if args.B == 0 { + return "", fmt.Errorf("cannot divide by zero") + } + return fmt.Sprintf("%.2f", args.A / args.B), nil + default: + return "", fmt.Errorf("unsupported operation: %s", args.Operation) + } +} + +func main() { + // Initialize Bifrost (tool registry creates in-process MCP automatically) + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: account, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + }) + + // Define tool schema + calculatorSchema := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "calculator", + Description: "Perform basic arithmetic operations", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "operation": map[string]interface{}{ + "type": "string", + "description": "The operation to perform", + "enum": []string{"add", "subtract", "multiply", "divide"}, + }, + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"operation", "a", "b"}, + }, + }, + } + + // Register the typed tool + err = client.RegisterMCPTool("calculator", "Perform arithmetic calculations", + func(args any) (string, error) { + // Convert args to typed struct + calculatorArgs := CalculatorArgs{} + if jsonBytes, err := json.Marshal(args); err == nil { + json.Unmarshal(jsonBytes, &calculatorArgs) + } + return calculatorHandler(calculatorArgs) + }, calculatorSchema) + + if err != nil { + panic(fmt.Sprintf("Failed to register tool: %v", err)) + } + + // Now use the tool in requests + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Calculate 15.5 + 24.3"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + }, + } + + response, err := client.ChatCompletionRequest(context.Background(), request) + // The model can now use the calculator tool automatically +} +``` + +### Tool Registry Benefits + +- **Type Safety**: Compile-time checking of tool arguments and return types +- **Performance**: In-process execution with zero network overhead +- **Simplicity**: No external MCP server setup required +- **Integration**: Tools are automatically available to all AI requests +- **Error Handling**: Structured error responses with detailed context + +--- + +## Advanced Configuration + +### Tool and Client Filtering + +Control which tools and clients are available per request or globally: + +**Request-Level Filtering:** + + + + +Use context values to filter clients and tools per request: + +```go +// Include only specific clients +ctx := context.WithValue(context.Background(), "mcp-include-clients", []string{"filesystem", "web-search"}) + +// Exclude specific tools +ctx = context.WithValue(ctx, "mcp-exclude-tools", "debug_tool,internal_tool") + +// Include only specific tools +ctx = context.WithValue(ctx, "mcp-include-tools", "search,read_file") + +// Exclude specific clients +ctx = context.WithValue(ctx, "mcp-exclude-clients", "admin-tools") + +response, err := client.ChatCompletionRequest(ctx, request) +``` + + + + +Use headers to filter clients and tools per request: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-bf-mcp-include-clients: filesystem,web-search" \ + -H "x-bf-mcp-exclude-tools: debug_tool,internal_tool" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": "Search for recent AI developments" + } + ] + }' + +# Alternative filtering options: +# -H "x-bf-mcp-include-tools: search,read_file" # Whitelist specific tools +# -H "x-bf-mcp-exclude-clients: admin-tools" # Blacklist specific clients +``` + +**Available MCP Headers:** +- `x-bf-mcp-include-clients`: Comma-separated list of clients to include +- `x-bf-mcp-exclude-clients`: Comma-separated list of clients to exclude +- `x-bf-mcp-include-tools`: Comma-separated list of tools to include +- `x-bf-mcp-exclude-tools`: Comma-separated list of tools to exclude + + + + +**Filtering Priority Rules:** + +1. **Request-Level vs Config-Level**: Request-level filtering (context/headers) takes priority over config-level filtering and can override it + +2. **Include vs Exclude Priority**: + - **Include lists are strict whitelists**: If `include-clients`/`include-tools` is specified, ONLY those clients/tools are allowed + - **Whitelist priority**: When both include and exclude are specified and a tool/client is in both, include takes priority + +**Client Configuration Filtering:** +```json +{ + "name": "external-api", + "connection_type": "http", + "connection_string": "https://api.example.com/mcp", + "tools_to_execute": ["search", "summarize"], // Whitelist specific tools + "tools_to_skip": ["delete", "admin_action"] // Blacklist specific tools +} +``` + +### Environment Variables + +Use environment variables for sensitive configuration: + +**Gateway:** +```json +{ + "name": "secure-api", + "connection_type": "http", + "connection_string": "env.SECURE_MCP_URL", // References $SECURE_MCP_URL + "stdio_config": { + "command": "python", + "args": ["-m", "secure_server"], + "envs": ["API_SECRET", "DATABASE_URL"] // Required environment variables + } +} +``` + +**Environment variables are:** +- Automatically resolved during client connection +- Redacted in API responses and UI for security +- Validated at startup to ensure all required variables are set + +### Client State Management + +Monitor and manage MCP client connections: + + + + +```go +// Get all connected clients and their status +clients, err := client.GetMCPClients() +for _, mcpClient := range clients { + fmt.Printf("Client: %s, State: %s, Tools: %v\n", + mcpClient.Name, mcpClient.State, mcpClient.Tools) +} + +// Reconnect a disconnected client +err = client.ReconnectMCPClient("filesystem-tools") + +// Add new client at runtime +err = client.AddMCPClient(newClientConfig) + +// Remove client +err = client.RemoveMCPClient("old-client") + +// Edit client tools +err = client.EditMCPClientTools("filesystem-tools", + []string{"read_file", "write_file"}, // tools to add + []string{"delete_file"}) // tools to remove +``` + + + + +```bash +# Get client status +curl http://localhost:8080/api/mcp/clients + +# Reconnect client +curl -X POST http://localhost:8080/api/mcp/client/filesystem-tools/reconnect + +# Add new client +curl -X POST http://localhost:8080/api/mcp/client \ + -H "Content-Type: application/json" \ + -d '{ + "name": "new-filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": "node", + "args": ["fs-server.js"] + } + }' + +# Edit client tools +curl -X PUT http://localhost:8080/api/mcp/client/filesystem-tools \ + -H "Content-Type: application/json" \ + -d '{ + "tools_to_add": ["read_file", "write_file"], + "tools_to_remove": ["delete_file"] + }' + +# Remove client +curl -X DELETE http://localhost:8080/api/mcp/client/old-client +``` + + + + +**Connection States:** +- **Connected**: Client is active and tools are available +- **Connecting**: Client is establishing connection +- **Disconnected**: Client lost connection but can be reconnected +- **Error**: Client configuration or connection failed + +--- + +## Architecture Details + +For detailed information about MCP's internal architecture, concurrency model, tool discovery process, and performance characteristics, see the [MCP Architecture Guide](../architecture/core/mcp). diff --git a/docs/features/observability.mdx b/docs/features/observability.mdx new file mode 100644 index 000000000..870f62e58 --- /dev/null +++ b/docs/features/observability.mdx @@ -0,0 +1,225 @@ +--- +title: "Observability" +description: "Integrate Maxim SDK for comprehensive LLM observability, tracing, and evaluation." +icon: "binoculars" +--- + +## Overview + +Bifrost provides comprehensive LLM observability through the **Maxim plugin**, enabling seamless tracking, evaluation, and analysis of AI interactions. The plugin automatically forwards all LLM requests and responses to Maxim's platform for detailed monitoring and performance insights. + +![Maxim Logs](../media/maxim-logs.png) + +--- + +## Setup + +The Maxim plugin enables seamless observability and evaluation of LLM interactions by forwarding inputs/outputs to Maxim's platform: + + + + +```go +package main + +import ( + "context" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + maxim "github.com/maximhq/bifrost/plugins/maxim" +) + +func main() { + // Initialize Maxim plugin + maximPlugin, err := maxim.Init(maxim.Config{ + ApiKey: "your_maxim_api_key", + LogRepoId: "your_default_repo_id", // Optional: fallback repository + }) + if err != nil { + panic(err) + } + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{maximPlugin}, + }) + if err != nil { + panic(err) + } + defer client.Shutdown() + + // All requests will now be traced to Maxim +} +``` + + + + +For HTTP transport, configure via environment variables: + +```json +{ + "plugins": [ + { + "enabled": true, + "name": "maxim", + "config": { + "api_key": "your_maxim_api_key", + "log_repo_id": "your_default_repo_id" + } + } + ] +} +``` + + + + +## Configuration + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `ApiKey` | `string` | βœ… Yes | Your Maxim API key for authentication | +| `LogRepoId` | `string` | ❌ No | Default log repository ID (can be overridden per request) | + +## Repository Selection + +The plugin uses repository selection with the following priority: + +1. **Header/Context Repository** - Highest priority +2. **Default Repository** (from plugin config) - Fallback +3. **Skip Logging** - If neither is available + + + + +```go +ctx := context.Background() + +// Use specific repository for this request +ctx = context.WithValue(ctx, maxim.LogRepoIDKey, "project-specific-repo") +``` + + + + +```bash +# Use default repository (from config) +curl -X POST http://localhost:8080/v1/chat/completions \ + -d '{"model": "gpt-4", "messages": [...]}' + +# Override with specific repository +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-maxim-log-repo-id: project-specific-repo" \ + -d '{"model": "gpt-4", "messages": [...]}' +``` + + + + + +## Custom Trace Management + +### Trace Propagation + +The plugin supports custom session, trace, and generation IDs for advanced tracing scenarios: + + + +```go +ctx := context.Background() + +// Prefer typed keys from the Maxim plugin +ctx = context.WithValue(ctx, maxim.TraceIDKey, "custom-trace-123") +ctx = context.WithValue(ctx, maxim.GenerationIDKey, "custom-gen-456") +ctx = context.WithValue(ctx, maxim.SessionIDKey, "user-session-789") + +// Optionally set human-friendly names +ctx = context.WithValue(ctx, maxim.TraceNameKey, "checkout-flow") +ctx = context.WithValue(ctx, maxim.GenerationNameKey, "rerank-step") +``` + + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-maxim-trace-id: custom-trace-123" \ + -H "x-bf-maxim-generation-id: custom-gen-456" \ + -H "x-bf-maxim-session-id: user-session-789" \ + -H "x-bf-maxim-trace-name: checkout-flow" \ + -H "x-bf-maxim-generation-name: rerank-step" \ + -d '{"model": "gpt-4", "messages": [...]}' +``` + + + +### Custom Tags + +You can add custom tags to traces for enhanced filtering and analytics: + + + + +```go +ctx := context.Background() + +// Pass arbitrary tag key-values via context map +tags := map[string]string{ + "environment": "production", + "user-id": "user-123", + "feature-flag": "new-ui", +} +ctx = context.WithValue(ctx, maxim.TagsKey, tags) +``` + + + + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "x-bf-maxim-environment: production" \ + -H "x-bf-maxim-user-id: user-123" \ + -H "x-bf-maxim-feature-flag: new-ui" \ + -d '{"model": "gpt-4", "messages": [...]}' +``` + +Reserved keys are `session-id`, `trace-id`, `trace-name`, `generation-id`, `generation-name`, `log-repo-id`. All other `x-bf-maxim-*` headers are treated as tags. + + + + +## Supported Request Types + +The plugin supports the following Bifrost request types: + +- Text Completion +- Chat Completion + +## Monitoring & Analytics + +Once configured, monitor your AI apps in the [Maxim Dashboard](https://getmaxim.ai/). Maxim is an end-to-end evaluation & observability platform built to help teams ship AI agents faster while maintaining high quality. + +* **Experiment / Prompt Engineering** + Playground++ for prompt design: versioning, comparison (A/B), visual chaining, low-code tooling. + +* **Simulation & Evaluation** + Test agents over thousands of scenarios, both automated (statistical, programmatic) and human-in-the-loop for edge cases. Custom and off-the-shelf evaluators. + +* **Observability / Monitoring** + Real-time traces, logging, debugging of multi-agent workflows, live issue tracking, alerts when quality or performance degrade. + +* **Data Engine & Dataset Management** + Support for multi-modal datasets, import & continuous curation, feedback/annotation pipelines, data splitting for experiments. + +* **Governance, Security & Compliance** + Features like SOC 2 Type II compliance, enterprise security controls, permissions, auditability. + +* **Alerts & SLAs**: Threshold-based notifications to keep quality and latency in guardrails + +## Next Steps + +Now that you have observability set up with the Maxim plugin, explore these related topics: + +- **[Tracing](./tracing)** - Deep-dive into request/response logging and correlation +- **[Telemetry](./telemetry)** - Prometheus metrics, dashboards, and alerting +- **[Governance](./governance)** - Virtual keys, per-team controls, and usage limits diff --git a/docs/features/plugins/circuit-breaker.mdx b/docs/features/plugins/circuit-breaker.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/features/plugins/jsonparser.mdx b/docs/features/plugins/jsonparser.mdx new file mode 100644 index 000000000..379f09f9a --- /dev/null +++ b/docs/features/plugins/jsonparser.mdx @@ -0,0 +1,306 @@ +--- +title: JSON Parser +description: A simple Bifrost plugin that handles partial JSON chunks in streaming responses by making them valid JSON objects. +icon: "code-branch" +--- + +## Overview + +When using AI providers that stream JSON responses, the individual chunks often contain incomplete JSON that cannot be parsed directly. This plugin automatically detects and fixes partial JSON chunks by adding the necessary closing braces, brackets, and quotes to make them valid JSON. + +## Features + +- **Automatic JSON Completion**: Detects partial JSON and adds missing closing characters +- **Streaming Only**: Processes only streaming responses (non-streaming responses are ignored) +- **Flexible Usage Modes**: Supports two usage types for different deployment scenarios +- **Safe Fallback**: Returns original content if JSON cannot be fixed +- **Memory Leak Prevention**: Automatic cleanup of stale accumulated content with configurable intervals +- **Zero Dependencies**: Only depends on Go's standard library + +## Usage + +### Usage Types + +The plugin supports two usage types: + +1. **AllRequests**: Processes all streaming responses automatically +2. **PerRequest**: Processes only when explicitly enabled via request context + + +```go +package main + +import ( + "time" + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/jsonparser" +) + +func main() { + // Create the JSON parser plugin for all requests + jsonPlugin := jsonparser.NewJsonParserPlugin(jsonparser.PluginConfig{ + Usage: jsonparser.AllRequests, + CleanupInterval: 2 * time.Minute, // Cleanup every 2 minutes + MaxAge: 10 * time.Minute, // Remove entries older than 10 minutes + }) + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &MyAccount{}, + Plugins: []schemas.Plugin{ + jsonPlugin, + }, + }) + + if err != nil { + panic(err) + } + + // Use the client normally - JSON parsing happens automatically + // in the PostHook for all streaming responses +} +``` + +### PerRequest Mode + +```go +package main + +import ( + "context" + "time" + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/jsonparser" +) + +func main() { + // Create the JSON parser plugin for per-request control + jsonPlugin := jsonparser.NewJsonParserPlugin(jsonparser.PluginConfig{ + Usage: jsonparser.PerRequest, + CleanupInterval: 2 * time.Minute, // Cleanup every 2 minutes + MaxAge: 10 * time.Minute, // Remove entries older than 10 minutes + }) + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &MyAccount{}, + Plugins: []schemas.Plugin{ + jsonPlugin, + }, + }) + + if err != nil { + panic(err) + } + + ctx := context.WithValue(context.Background(), jsonparser.EnableStreamingJSONParser, true) + + // Enable JSON parsing for specific requests + stream, bifrostErr := client.ChatCompletionStreamRequest(ctx, request) + if bifrostErr != nil { + // handle error + } + for chunk := range stream { + _ = chunk // handle each streaming chunk + } +} +``` + +### Configuration + +```go +// Custom cleanup configuration +plugin := jsonparser.NewJsonParserPlugin(jsonparser.PluginConfig{ + Usage: jsonparser.AllRequests, + CleanupInterval: 2 * time.Minute, // Cleanup every 2 minutes + MaxAge: 10 * time.Minute, // Remove entries older than 10 minutes +}) +``` + +#### Default Values + +- **CleanupInterval**: 5 minutes (how often to run cleanup) +- **MaxAge**: 30 minutes (how old entries can be before cleanup) +- **Usage**: Must be specified (AllRequests or PerRequest) + +### Context Key for PerRequest Mode + +When using `PerRequest` mode, the plugin checks for the context key `jsonparser.EnableStreamingJSONParser` with a boolean value: + +- `true`: Enable JSON parsing for this request +- `false`: Disable JSON parsing for this request +- Key not present: Disable JSON parsing for this request + +**Example:** + +```go +import ( + "context" + + "github.com/maximhq/bifrost/plugins/jsonparser" +) + +// Enable JSON parsing for this request +ctx := context.WithValue(context.Background(), jsonparser.EnableStreamingJSONParser, true) + +// Disable JSON parsing for this request +ctx := context.WithValue(context.Background(), jsonparser.EnableStreamingJSONParser, false) + +// No context key - JSON parsing disabled (default behavior) +ctx := context.Background() +``` + +## How It Works + +The plugin implements an optimized `parsePartialJSON` function with the following steps: + +1. **Usage Check**: Determines if processing should occur based on usage type and context +2. **Validates Input**: First tries to parse the string as valid JSON +3. **Character Analysis**: If invalid, processes the string character-by-character to track: + - String boundaries (inside/outside quotes) + - Escape sequences + - Opening/closing braces and brackets +4. **Auto-Completion**: Adds missing closing characters in the correct order +5. **Validation**: Verifies the completed JSON is valid +6. **Fallback**: Returns original content if completion fails + +### Memory Management + +The plugin automatically manages memory by: + +1. **Accumulating Content**: Stores partial JSON chunks with timestamps for each request +2. **Periodic Cleanup**: Runs a background goroutine that removes stale entries based on `MaxAge` +3. **Request Completion**: Automatically clears accumulated content when requests complete successfully +4. **Configurable Intervals**: Allows customization of cleanup frequency and retention periods + +### Real-Life Streaming Example + +Here's a practical example showing how the JSON parser plugin fixes broken JSON chunks in streaming responses: + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "time" + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/jsonparser" +) + +func main() { + // Create JSON parser plugin + jsonPlugin := jsonparser.NewJsonParserPlugin(jsonparser.PluginConfig{ + Usage: jsonparser.AllRequests, + CleanupInterval: 2 * time.Minute, + MaxAge: 10 * time.Minute, + }) + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &MyAccount{}, + Plugins: []schemas.Plugin{jsonPlugin}, + }) + if err != nil { + panic(err) + } + defer client.Cleanup() + + // Request structured JSON response + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Messages: []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Return user profile as JSON: {\"name\": \"John Doe\", \"email\": \"john@example.com\"}"), + }, + }, + }, + } + + // Stream the response + stream, bifrostErr := client.ChatCompletionStreamRequest(context.Background(), request) + if bifrostErr != nil { + panic(bifrostErr) + } + + fmt.Println("Streaming JSON response:") + for chunk := range stream { + if chunk.BifrostResponse != nil && len(chunk.BifrostResponse.Choices) > 0 { + choice := chunk.BifrostResponse.Choices[0] + if choice.BifrostStreamResponseChoice != nil && choice.BifrostStreamResponseChoice.Delta.Content != nil { + content := *choice.BifrostStreamResponseChoice.Delta.Content + fmt.Printf("Chunk: %s\n", content) + + // With JSON parser, you can parse each chunk immediately + var jsonData map[string]interface{} + if err := json.Unmarshal([]byte(content), &jsonData); err == nil { + fmt.Printf("βœ… Valid JSON parsed successfully\n") + } else { + fmt.Printf("❌ Invalid JSON: %v\n", err) + } + } + } + } +} +``` + +**Without JSON Parser** (raw streaming chunks): +``` +Chunk 1: `{` ❌ Invalid JSON +Chunk 2: `{"name"` ❌ Invalid JSON +Chunk 3: `{"name": "John"` ❌ Invalid JSON +Chunk 4: `{"name": "John Doe"` ❌ Invalid JSON +``` + +**With JSON Parser** (processed chunks): +``` +Chunk 1: `{}` βœ… Valid JSON +Chunk 2: `{"name": ""}` βœ… Valid JSON +Chunk 3: `{"name": "John"}` βœ… Valid JSON +Chunk 4: `{"name": "John Doe"}` βœ… Valid JSON +``` + +### Use Cases + +- **Function Calling**: Stream tool call arguments as valid JSON throughout the response +- **Structured Data**: Stream complex JSON objects (user profiles, product catalogs) progressively +- **Real-time Parsing**: Enable client-side JSON parsing at each streaming step without waiting for completion +- **API Integration**: Forward streaming JSON to downstream services that expect valid JSON +- **Live Updates**: Update UI components with valid JSON data as it streams in + +### Example Transformations + +| Input | Output | +|-------|--------| +| `{"name": "John"` | `{"name": "John"}` | +| `["apple", "banana"` | `["apple", "banana"]` | +| `{"user": {"name": "John"` | `{"user": {"name": "John"}}` | +| `{"message": "Hello\nWorld"` | `{"message": "Hello\nWorld"}` | +| `""` (empty string) | `{}` | +| `" "` (whitespace only) | `{}` | + +## Testing + +Run the test suite: + +```bash +cd plugins/jsonparser +go test -v +``` + +The tests cover: +- Plugin interface compliance +- Both usage types (AllRequests and PerRequest) +- Context-based enabling/disabling +- Streaming responses only (non-streaming responses are ignored) +- Various JSON completion scenarios +- Edge cases and error conditions +- Memory cleanup functionality with real and simulated requests +- Configuration options and default values \ No newline at end of file diff --git a/docs/features/plugins/mocker.mdx b/docs/features/plugins/mocker.mdx new file mode 100644 index 000000000..7749b15d8 --- /dev/null +++ b/docs/features/plugins/mocker.mdx @@ -0,0 +1,484 @@ +--- +title: "Mocker" +description: "Mock AI provider responses for testing, development, and simulation purposes." +icon: "mask" +--- + +## Quick Start + +### Minimal Configuration + +The simplest way to use the Mocker plugin is with no configuration - it will create a default catch-all rule: + +```go +package main + +import ( + "context" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + mocker "github.com/maximhq/bifrost/plugins/mocker" +) + +func main() { + // Create plugin with minimal config + plugin, err := mocker.NewMockerPlugin(mocker.MockerConfig{ + Enabled: true, // Default rule will be created automatically + }) + if err != nil { + panic(err) + } + + // Initialize Bifrost with the plugin + client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{plugin}, + }) + if err != nil { + panic(err) + } + defer client.Cleanup() + + // All requests will now return: "This is a mock response from the Mocker plugin" + response, _ := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello!"), + }, + }, + }, + }, + }) +} +``` + +### Custom Response + +```go +plugin, err := mocker.NewMockerPlugin(mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "openai-mock", + Enabled: true, + Probability: 1.0, // Always trigger + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + Message: "Hello! This is a custom mock response for OpenAI.", + Usage: &mocker.Usage{ + PromptTokens: 15, + CompletionTokens: 25, + TotalTokens: 40, + }, + }, + }, + }, + }, + }, +}) +``` + +## Installation + +Add the plugin to your project: + + ```bash + go get github.com/maximhq/bifrost/plugins/mocker + ``` + +Import in your code: + + ```go + import mocker "github.com/maximhq/bifrost/plugins/mocker" + ``` + +## Basic Usage + +### Creating the Plugin + +```go +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorPassthrough, // "passthrough", "success", "error" + Rules: []mocker.MockRule{ + // Your rules here + }, +} + +plugin, err := mocker.NewMockerPlugin(config) +if err != nil { + log.Fatal(err) +} +``` + +### Adding to Bifrost + +```go +client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), +}) +``` + +### Disabling the Plugin + +```go +config := mocker.MockerConfig{ + Enabled: false, // All requests pass through to real providers +} +``` + +## Key Features + +### Template Variables + +Create dynamic responses using templates: + +```go +Response{ + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr("Hello from {{provider}} using model {{model}}!"), + }, +} +``` + +**Available Variables:** +- `{{provider}}` - Provider name (e.g., "openai", "anthropic") +- `{{model}}` - Model name (e.g., "gpt-4", "claude-3") +- `{{faker.*}}` - Fake data generation (see Configuration Reference) + +### Weighted Response Selection + +Configure multiple responses with different probabilities: + +```go +Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.8, // 80% chance + Content: &mocker.SuccessResponse{ + Message: "Success response", + }, + }, + { + Type: mocker.ResponseTypeError, + Weight: 0.2, // 20% chance + Error: &mocker.ErrorResponse{ + Message: "Rate limit exceeded", + Type: stringPtr("rate_limit"), + Code: stringPtr("429"), + }, + }, +} +``` + +### Latency Simulation + +Add realistic delays to responses: + +```go +// Fixed latency +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeFixed, + Min: 250 * time.Millisecond, +} + +// Variable latency +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeUniform, + Min: 100 * time.Millisecond, + Max: 500 * time.Millisecond, +} +``` + +### Advanced Matching + +#### Regex Message Matching +```go +Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*support.*|.*help.*`), +} +``` + +#### Request Size Filtering +```go +Conditions: mocker.Conditions{ + RequestSize: &mocker.SizeRange{ + Min: 100, // bytes + Max: 1000, // bytes + }, +} +``` + +### Faker Data Generation + +Create realistic test data using faker variables: + +```go +{ + Name: "user-profile-example", + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`User Profile: +- Name: {{faker.name}} +- Email: {{faker.email}} +- Company: {{faker.company}} +- Address: {{faker.address}}, {{faker.city}} +- Phone: {{faker.phone}} +- User ID: {{faker.uuid}} +- Join Date: {{faker.date}} +- Premium Account: {{faker.boolean}}`), + }, + }, + }, +} +``` + +### Statistics and Monitoring + +Get runtime statistics for monitoring: + +```go +stats := plugin.GetStatistics() +fmt.Printf("Plugin enabled: %v\n", stats.Enabled) +fmt.Printf("Total requests: %d\n", stats.TotalRequests) +fmt.Printf("Mocked requests: %d\n", stats.MockedRequests) + +// Rule-specific stats +for ruleName, ruleStats := range stats.Rules { + fmt.Printf("Rule %s: %d triggers\n", ruleName, ruleStats.Triggers) +} +``` + +## Configuration Reference + +### MockerConfig + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Enabled` | `bool` | `false` | Enable/disable the entire plugin | +| `DefaultBehavior` | `string` | `"passthrough"` | Action when no rules match: `"passthrough"`, `"success"`, `"error"` | +| `GlobalLatency` | `*Latency` | `nil` | Global latency applied to all rules | +| `Rules` | `[]MockRule` | `[]` | List of mock rules evaluated in priority order | + +### MockRule + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Name` | `string` | - | Unique rule name for identification | +| `Enabled` | `bool` | `true` | Enable/disable this specific rule | +| `Priority` | `int` | `0` | Higher numbers = higher priority | +| `Probability` | `float64` | `1.0` | Activation probability (0.0=never, 1.0=always) | +| `Conditions` | `Conditions` | `{}` | Matching conditions (empty = match all) | +| `Responses` | `[]Response` | - | Possible responses (weighted random selection) | +| `Latency` | `*Latency` | `nil` | Rule-specific latency override | + +### Conditions + +| Field | Type | Description | +|-------|------|-------------| +| `Providers` | `[]string` | Match specific providers: `["openai", "anthropic"]` | +| `Models` | `[]string` | Match specific models: `["gpt-4", "claude-3"]` | +| `MessageRegex` | `*string` | Regex pattern to match message content | +| `RequestSize` | `*SizeRange` | Request size constraints in bytes | + +### Response + +| Field | Type | Description | +|-------|------|-------------| +| `Type` | `string` | Response type: `"success"` or `"error"` | +| `Weight` | `float64` | Weight for random selection (default: 1.0) | +| `Content` | `*SuccessResponse` | Required if `Type="success"` | +| `Error` | `*ErrorResponse` | Required if `Type="error"` | +| `AllowFallbacks` | `*bool` | Control fallback behavior (`nil`=allow, `false`=block) | + +### SuccessResponse + +| Field | Type | Description | +|-------|------|-------------| +| `Message` | `string` | Static response message | +| `MessageTemplate` | `*string` | Template with variables: `{{provider}}`, `{{model}}`, `{{faker.*}}` | +| `Model` | `*string` | Override model name in response | +| `Usage` | `*Usage` | Token usage information | +| `FinishReason` | `*string` | Completion reason (default: `"stop"`) | +| `CustomFields` | `map[string]interface{}` | Additional metadata fields | + +### ErrorResponse + +| Field | Type | Description | +|-------|------|-------------| +| `Message` | `string` | Error message to return | +| `Type` | `*string` | Error type (e.g., `"rate_limit"`, `"auth_error"`) | +| `Code` | `*string` | Error code (e.g., `"429"`, `"401"`) | +| `StatusCode` | `*int` | HTTP status code | + +### Latency + +| Field | Type | Description | +|-------|------|-------------| +| `Type` | `string` | Latency type: `"fixed"` or `"uniform"` | +| `Min` | `time.Duration` | Minimum/exact latency (use `time.Millisecond`) | +| `Max` | `time.Duration` | Maximum latency (required for `"uniform"`) | + +**Important**: Use Go's `time.Duration` constants: +- βœ… Correct: `100 * time.Millisecond` +- ❌ Wrong: `100` (nanoseconds, barely noticeable) + +### Faker Variables + +#### Personal Information +- `{{faker.name}}` - Full name +- `{{faker.first_name}}` - First name only +- `{{faker.last_name}}` - Last name only +- `{{faker.email}}` - Email address +- `{{faker.phone}}` - Phone number + +#### Location +- `{{faker.address}}` - Street address +- `{{faker.city}}` - City name +- `{{faker.state}}` - State/province +- `{{faker.zip_code}}` - Postal code + +#### Business +- `{{faker.company}}` - Company name +- `{{faker.job_title}}` - Job title + +#### Text and Data +- `{{faker.lorem_ipsum}}` - Lorem ipsum text +- `{{faker.lorem_ipsum:10}}` - Lorem ipsum with 10 words +- `{{faker.uuid}}` - UUID v4 +- `{{faker.hex_color}}` - Hex color code + +#### Numbers and Dates +- `{{faker.integer}}` - Random integer (1-100) +- `{{faker.integer:10,50}}` - Random integer between 10-50 +- `{{faker.float}}` - Random float (0-100, 2 decimals) +- `{{faker.float:1,10}}` - Random float between 1-10 +- `{{faker.boolean}}` - Random boolean +- `{{faker.date}}` - Date (YYYY-MM-DD format) +- `{{faker.datetime}}` - Datetime (YYYY-MM-DD HH:MM:SS format) + +## Best Practices + +### Rule Organization + +```go +// Use priority to control rule evaluation order +rules := []mocker.MockRule{ + {Name: "specific-error", Priority: 100, Conditions: /* specific */}, + {Name: "general-success", Priority: 50, Conditions: /* general */}, + {Name: "catch-all", Priority: 0, Conditions: /* empty */}, +} +``` + +### Development vs Production + +```go +// Development: High mock rate +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + {Probability: 1.0}, // Always mock + }, +} + +// Production: Occasional testing +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + {Probability: 0.1}, // 10% mock rate + }, +} +``` + +### Performance Considerations + +- Place specific conditions before general ones (higher priority) +- Use simple string matching over complex regex when possible +- Keep response templates reasonably sized +- Consider disabling debug logging in production + +### Testing Your Configuration + +```go +func validateMockerConfig(config mocker.MockerConfig) error { + _, err := mocker.NewMockerPlugin(config) + return err +} + +// Test before deployment +if err := validateMockerConfig(yourConfig); err != nil { + log.Fatalf("Invalid mocker configuration: %v", err) +} +``` + +## Common Issues + +### Plugin Not Triggering + +1. Check if plugin is enabled: `Enabled: true` +2. Verify rule is enabled: `rule.Enabled: true` +3. Check probability: `Probability: 1.0` for testing +4. Verify conditions match your request + +### Latency Not Working + +Use `time.Duration` constants, not raw integers: + +```go +// ❌ Wrong: 100 nanoseconds (barely noticeable) +Min: 100 + +// βœ… Correct: 100 milliseconds +Min: 100 * time.Millisecond +``` + +### Regex Not Matching + +Test your regex pattern and ensure proper escaping: + +```go +// Case-insensitive matching +MessageRegex: stringPtr(`(?i).*help.*`) + +// Escape special characters +MessageRegex: stringPtr(`\$\d+\.\d+`) // Match $12.34 +``` + +### Controlling Fallbacks + +```go +Response{ + Type: mocker.ResponseTypeError, + AllowFallbacks: boolPtr(false), // Block fallbacks + Error: &mocker.ErrorResponse{ + Message: "Authentication failed", + }, +} +``` + +### Debug Mode + +Enable debug logging to troubleshoot: + +```go +client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), +}) +``` diff --git a/docs/features/semantic-caching.mdx b/docs/features/semantic-caching.mdx new file mode 100644 index 000000000..d6f16cee1 --- /dev/null +++ b/docs/features/semantic-caching.mdx @@ -0,0 +1,519 @@ +--- +title: "Semantic Caching" +description: "Intelligent response caching based on semantic similarity. Reduce costs and latency by serving cached responses for semantically similar requests." +icon: "database" +--- + +## Overview + +Semantic caching uses vector similarity search to intelligently cache AI responses, serving cached results for semantically similar requests even when the exact wording differs. This dramatically reduces API costs and latency for repeated or similar queries. + +**Key Benefits:** +- **Cost Reduction**: Avoid expensive LLM API calls for similar requests +- **Improved Performance**: Sub-millisecond cache retrieval vs multi-second API calls +- **Intelligent Matching**: Semantic similarity beyond exact text matching +- **Streaming Support**: Full streaming response caching with proper chunk ordering + +--- + +## Core Features + +- **Dual-Layer Caching**: Exact hash matching + semantic similarity search (customizable threshold) +- **Vector-Powered Intelligence**: Uses embeddings to find semantically similar requests +- **Dynamic Configuration**: Per-request TTL and threshold overrides via headers/context +- **Model/Provider Isolation**: Separate caching per model and provider combination + +--- + +## Vector Store Setup + + + + + +```go +import ( + "context" + "github.com/maximhq/bifrost/framework/vectorstore" + "github.com/maximhq/bifrost/core/schemas" +) + +// Configure vector store +vectorConfig := &vectorstore.Config{ + Enabled: true, + Type: vectorstore.VectorStoreTypeWeaviate, + Config: vectorstore.WeaviateConfig{ + Scheme: "http", + Host: "localhost:8080", + }, +} + +// Create vector store +store, err := vectorstore.NewVectorStore(context.Background(), vectorConfig, logger) +if err != nil { + log.Fatal("Failed to create vector store:", err) +} +``` + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "host": "localhost:8080", + "scheme": "http", + } + } +} +``` + +**For Weaviate Cloud:** +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "host": "your-cluster.weaviate.network", + "scheme": "https", + "api_key": "your-weaviate-api-key" + } + } +} +``` + + + + + +--- + +## Semantic Cache Configuration + + + + + +```go +import ( + "github.com/maximhq/bifrost/plugins/semanticcache" + "github.com/maximhq/bifrost/core/schemas" +) + +// Configure semantic cache plugin +cacheConfig := semanticcache.Config{ + // Embedding model configuration (Required) + Provider: schemas.OpenAI, + Keys: []schemas.Key{{Value: "sk-..."}}, + EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, + + // Cache behavior + TTL: 5 * time.Minute, // Time to live for cached responses (default: 5 minutes) + Threshold: 0.8, // Similarity threshold for cache lookup (default: 0.8) + CleanUpOnShutdown: true, // Clean up cache on shutdown (default: false) + + // Conversation behavior + ConversationHistoryThreshold: 5, // Skip caching if conversation has > N messages (default: 3) + ExcludeSystemPrompt: bifrost.Ptr(false), // Exclude system messages from cache key (default: false) + + // Advanced options + CacheByModel: bifrost.Ptr(true), // Include model in cache key (default: true) + CacheByProvider: bifrost.Ptr(true), // Include provider in cache key (default: true) +} + +// Create plugin +plugin, err := semanticcache.Init(context.Background(), cacheConfig, logger, store) +if err != nil { + log.Fatal("Failed to create semantic cache plugin:", err) +} + +// Add to Bifrost config +bifrostConfig := schemas.BifrostConfig{ + Plugins: []schemas.Plugin{plugin}, + // ... other config +} +``` + + + + + +![Semantic Cache Plugin Configuration](../media/ui-semantic-cache-config.png) + +**Note**: Make sure you have a vector store setup (using `config.json`) before configuring the semantic cache plugin. + +1. **Navigate to Settings** + - Open Bifrost UI at `http://localhost:8080` + - Go to Settings. + +2. **Configure Semantic Cache Plugin** + +- Toggle the plugin switch to enable it, and fill in the required fields. + +**Required Fields:** +- **Provider**: The provider to use for caching. +- **Embedding Model**: The embedding model to use for caching. + +**Note**: Changes will need a restart of the Bifrost server to take effect, because the plugin is loaded on startup only. + + + + + +```json +{ + "plugins": [ + { + "enabled": true, + "name": "semantic_cache", + "config": { + "provider": "openai", + "embedding_model": "text-embedding-3-small", + + "cleanup_on_shutdown": true, + "ttl": "5m", + "threshold": 0.8, + + "conversation_history_threshold": 3, + "exclude_system_prompt": false, + + "cache_by_model": true, + "cache_by_provider": true + } + } + ] +} +``` + +> **Note**: All the available keys will be taken from the provider config on initialization, so make sure to add the keys to the provider you have specified in the config. Any updates to the keys will not be reflected until next restart. + +**TTL Format Options:** +- Duration strings: `"30s"`, `"5m"`, `"1h"`, `"24h"` +- Numeric seconds: `300` (5 minutes), `3600` (1 hour) + + + + + +--- + +## Cache Triggering + + +**Cache Key is mandatory**: Semantic caching only activates when a cache key is provided. Without a cache key, requests bypass caching entirely. + + + + + +Must set cache key in request context: + +```go +// This request WILL be cached +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +response, err := client.ChatCompletionRequest(ctx, request) + +// This request will NOT be cached (no context value) +response, err := client.ChatCompletionRequest(context.Background(), request) +``` + + + + +Must set cache key in request header `x-bf-cache-key`: + +```bash +# This request WILL be cached +curl -H "x-bf-cache-key: session-123" ... + +# This request will NOT be cached (no header) +curl ... +``` + + + + + +## Per-Request Overrides + +Override default TTL and similarity threshold per request: + + + + + +You can set TTL and threshold in the request context, in the keys you configured in the plugin config: + +```go +// Go SDK: Custom TTL and threshold +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +ctx = context.WithValue(ctx, semanticcache.CacheTTLKey, 30*time.Second) +ctx = context.WithValue(ctx, semanticcache.CacheThresholdKey, 0.9) +``` + + + + + +You can set TTL and threshold in the request headers `x-bf-cache-ttl` and `x-bf-cache-threshold`: + +```bash +# HTTP: Custom TTL and threshold +curl -H "x-bf-cache-key: session-123" \ + -H "x-bf-cache-ttl: 30s" \ + -H "x-bf-cache-threshold: 0.9" ... +``` + + + + + +--- + +## Advanced Cache Control + +### Cache Type Control + +Control which caching mechanism to use per request: + + + + + +```go +// Use only direct hash matching (fastest) +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +ctx = context.WithValue(ctx, semanticcache.CacheTypeKey, semanticcache.CacheTypeDirect) + +// Use only semantic similarity search +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +ctx = context.WithValue(ctx, semanticcache.CacheTypeKey, semanticcache.CacheTypeSemantic) + +// Default behavior: Direct + semantic fallback (if not specified) +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +``` + + + + + +```bash +# Direct hash matching only +curl -H "x-bf-cache-key: session-123" \ + -H "x-bf-cache-type: direct" ... + +# Semantic similarity search only +curl -H "x-bf-cache-key: session-123" \ + -H "x-bf-cache-type: semantic" ... + +# Default: Both (if header not specified) +curl -H "x-bf-cache-key: session-123" ... +``` + + + + + +### No-Store Control + +Disable response caching while still allowing cache reads: + + + + + +```go +// Read from cache but don't store the response +ctx = context.WithValue(ctx, semanticcache.CacheKey, "session-123") +ctx = context.WithValue(ctx, semanticcache.CacheNoStoreKey, true) +``` + + + + + +```bash +# Read from cache but don't store response +curl -H "x-bf-cache-key: session-123" \ + -H "x-bf-cache-no-store: true" ... +``` + + + + + +--- + +## Conversation Configuration + +### History Threshold Logic + +The `ConversationHistoryThreshold` setting skips caching for conversations with many messages to prevent false positives: + +**Why this matters:** +- **Semantic False Positives**: Long conversation histories have high probability of semantic matches with unrelated conversations due to topic overlap +- **Direct Cache Inefficiency**: Long conversations rarely have exact hash matches, making direct caching less effective +- **Performance**: Reduces vector store load by filtering out low-value caching scenarios + +```json +{ + "conversation_history_threshold": 3 // Skip caching if > 3 messages in conversation +} +``` + +**Recommended Values:** +- **1-2**: Very conservative (may miss valuable caching opportunities) +- **3-5**: Balanced approach (default: 3) +- **10+**: Cache longer conversations (higher false positive risk) + +### System Prompt Handling + +Control whether system messages are included in cache key generation: + +```json +{ + "exclude_system_prompt": false // Include system messages in cache key (default) +} +``` + +**When to exclude (`true`):** +- System prompts change frequently but content is similar +- Multiple system prompt variations for same use case +- Focus caching on user content similarity + +**When to include (`false`):** +- System prompts significantly change response behavior +- Each system prompt requires distinct cached responses +- Strict response consistency requirements + +--- + +## Cache Management + +### Cache Metadata Location + +When responses are served from semantic cache, 3 key variables are automatically added to the response: + +**Location**: `response.ExtraFields.CacheDebug` (as a JSON object) + +**Fields**: +- `CacheHit` (boolean): `true` if the response was served from the cache, `false` when lookup fails. +- `HitType` (string): `"semantic"` for similarity match, `"direct"` for hash match +- `CacheID` (string): Unique cache entry ID for management operations (present only for cache hits) + + +**Semantic Cache Only**: +- `ProviderUsed` (string): Provider used for the calculating semantic match embedding. (present for both cache hits and misses) +- `ModelUsed` (string): Model used for the calculating semantic match embedding. (present for both cache hits and misses) +- `InputTokens` (number): Number of tokens extracted from the request for the semantic match embedding calculation. (present for both cache hits and misses) +- `Threshold` (number): Similarity threshold used for the match. (present only for cache hits) +- `Similarity` (number): Similarity score for the match. (present only for cache hits) + +Example HTTP Response: + +```json +{ + "extra_fields": { + "cache_debug": { + "cache_hit": true, + "hit_type": "direct", + "cache_id": "550e8500-e29b-41d4-a725-446655440001", + } + } +} + +{ + "extra_fields": { + "cache_debug": { + "cache_hit": true, + "hit_type": "semantic", + "cache_id": "550e8500-e29b-41d4-a725-446655440001", + "threshold": 0.8, + "similarity": 0.95, + "provider_used": "openai", + "model_used": "gpt-4o-mini", + "input_tokens": 100 + } + } +} + +{ + "extra_fields": { + "cache_debug": { + "cache_hit": false, + "provider_used": "openai", + "model_used": "gpt-4o-mini", + "input_tokens": 20 + } + } +} +``` + + +These variables allow you to detect cached responses and get the cache entry ID needed for clearing specific entries. + +### Clear Specific Cache Entry + +Use the request ID from cached responses to clear specific entries: + + + + + +```go +// Clear specific entry by request ID +err := plugin.ClearCacheForRequestID("550e8400-e29b-41d4-a716-446655440000") + +// Clear all entries for a cache key +err := plugin.ClearCacheForKey("support-session-456") +``` + + + + + +```bash +# Clear specific cached entry by request ID +curl -X DELETE http://localhost:8080/api/cache/clear/550e8400-e29b-41d4-a716-446655440000 + +# Clear all entries for a cache key +curl -X DELETE http://localhost:8080/api/cache/clear-by-key/support-session-456 +``` + + + + + +### Cache Lifecycle & Cleanup + +The semantic cache automatically handles cleanup to prevent storage bloat: + +**Automatic Cleanup:** +- **TTL Expiration**: Entries are automatically removed when TTL expires +- **Shutdown Cleanup**: All cache entries are cleared from the vector store namespace and the namespace itself when Bifrost client shuts down +- **Namespace Isolation**: Each Bifrost instance uses isolated vector store namespaces to prevent conflicts + +**Manual Cleanup Options:** +- Clear specific entries by request ID (see examples above) +- Clear all entries for a cache key +- Restart Bifrost to clear all cache data + + +The semantic cache namespace and all its cache entries are deleted when Bifrost client shuts down **only if `cleanup_on_shutdown` is set to `true`**. By default (`cleanup_on_shutdown: false`), cache data persists between restarts. DO NOT use the plugin's namespace for external purposes. + + + +**Dimension Changes**: If you update the `dimension` config, the existing namespace will contain data with mixed dimensions, causing retrieval issues. To avoid this, either use a different `vector_store_namespace` or set `cleanup_on_shutdown: true` before restarting. + + +--- + + +**Vector Store Requirement**: Semantic caching requires a configured vector store (currently Weaviate only). Without vector store setup, the plugin will not function. + \ No newline at end of file diff --git a/docs/features/sso-with-google-github.mdx b/docs/features/sso-with-google-github.mdx new file mode 100644 index 000000000..aee362bc4 --- /dev/null +++ b/docs/features/sso-with-google-github.mdx @@ -0,0 +1,6 @@ +--- +title: "SSO with Google & GitHub" +description: "Secure single sign-on authentication with Google and GitHub OAuth providers." +tag: "Coming soon" +icon: "sign-in-alt" +--- \ No newline at end of file diff --git a/docs/features/telemetry.mdx b/docs/features/telemetry.mdx new file mode 100644 index 000000000..dcf790e66 --- /dev/null +++ b/docs/features/telemetry.mdx @@ -0,0 +1,293 @@ +--- +title: "Telemetry" +description: "Comprehensive Prometheus-based monitoring for Bifrost Gateway with custom metrics and labels." +icon: "gauge" +--- + +## Overview + +Bifrost provides built-in telemetry and monitoring capabilities through Prometheus metrics collection. The telemetry system tracks both HTTP-level performance metrics and upstream provider interactions, giving you complete visibility into your AI gateway's performance and usage patterns. + +**Key Features:** +- **Prometheus Integration** - Native metrics collection at `/metrics` endpoint +- **Comprehensive Tracking** - Success/error rates, token usage, costs, and cache performance +- **Custom Labels** - Configurable dimensions for detailed analysis +- **Dynamic Headers** - Runtime label injection via `x-bf-prom-*` headers +- **Cost Monitoring** - Real-time tracking of AI provider costs in USD +- **Cache Analytics** - Direct and semantic cache hit tracking +- **Async Collection** - Zero-latency impact on request processing +- **Multi-Level Tracking** - HTTP transport + upstream provider metrics + +The telemetry plugin operates asynchronously to ensure metrics collection doesn't impact request latency or connection performance. + +--- + +## Default Metrics + +### HTTP Transport Metrics + +These metrics track all incoming HTTP requests to Bifrost: + +| Metric | Type | Description | Labels | +|--------|------|-------------|---------| +| `http_requests_total` | Counter | Total number of HTTP requests | `path`, `method`, `status`, custom labels | +| `http_request_duration_seconds` | Histogram | Duration of HTTP requests | `path`, `method`, `status`, custom labels | +| `http_request_size_bytes` | Histogram | Size of incoming HTTP requests | `path`, `method`, `status`, custom labels | +| `http_response_size_bytes` | Histogram | Size of outgoing HTTP responses | `path`, `method`, `status`, custom labels | + +### Upstream Provider Metrics + +These metrics track requests forwarded to AI providers: + +| Metric | Type | Description | Labels | +|--------|------|-------------|---------| +| `bifrost_upstream_requests_total` | Counter | Total requests forwarded to upstream providers | `provider`, `model`, `method`, custom labels | +| `bifrost_upstream_latency_seconds` | Histogram | Latency of upstream provider requests | `provider`, `model`, `method`, custom labels | +| `bifrost_success_requests_total` | Counter | Total successful requests to upstream providers | `provider`, `model`, `method`, custom labels | +| `bifrost_error_requests_total` | Counter | Total failed requests to upstream providers | `provider`, `model`, `method`, custom labels | +| `bifrost_input_tokens_total` | Counter | Total input tokens sent to upstream providers | `provider`, `model`, `method`, custom labels | +| `bifrost_output_tokens_total` | Counter | Total output tokens received from upstream providers | `provider`, `model`, `method`, custom labels | +| `bifrost_cache_hits_total` | Counter | Total cache hits by type (direct/semantic) | `provider`, `model`, `method`, `cache_type`, custom labels | +| `bifrost_cost_total` | Counter | Total cost in USD for upstream provider requests | `provider`, `model`, `method`, custom labels | + +**Label Definitions:** +- `provider`: AI provider name (e.g., `openai`, `anthropic`, `azure`) +- `model`: Model name (e.g., `gpt-4o-mini`, `claude-3-sonnet`) +- `method`: Request type (`chat`, `text`, `embedding`, `speech`, `transcription`) +- `cache_type`: Cache hit type (`direct`, `semantic`) - only for cache hits metric +- `path`: HTTP endpoint path +- `status`: HTTP status code + +--- + +## Monitoring Examples + +### Success Rate Monitoring +Track the success rate of requests to different providers: + +```promql +# Success rate by provider +rate(bifrost_success_requests_total[5m]) / +rate(bifrost_upstream_requests_total[5m]) * 100 +``` + +### Token Usage Analysis +Monitor token consumption across different models: + +```promql +# Input tokens per minute by model +increase(bifrost_input_tokens_total[1m]) + +# Output tokens per minute by model +increase(bifrost_output_tokens_total[1m]) + +# Token efficiency (output/input ratio) +rate(bifrost_output_tokens_total[5m]) / +rate(bifrost_input_tokens_total[5m]) +``` + +### Cost Tracking +Monitor spending across providers and models: + +```promql +# Cost per second by provider +sum by (provider) (rate(bifrost_cost_total[1m])) + +# Daily cost estimate +sum by (provider) (increase(bifrost_cost_total[1d])) + +# Cost per request by provider and model +sum by (provider, model) (rate(bifrost_cost_total[5m])) / +sum by (provider, model) (rate(bifrost_upstream_requests_total[5m])) +``` + +### Cache Performance +Track cache effectiveness: + +```promql +# Cache hit rate by type +rate(bifrost_cache_hits_total[5m]) / +rate(bifrost_upstream_requests_total[5m]) * 100 + +# Direct vs semantic cache hits +sum by (cache_type) (rate(bifrost_cache_hits_total[5m])) +``` + +### Error Rate Analysis +Monitor error patterns: + +```promql +# Error rate by provider +rate(bifrost_error_requests_total[5m]) / +rate(bifrost_upstream_requests_total[5m]) * 100 + +# Errors by model +sum by (model) (rate(bifrost_error_requests_total[5m])) +``` + +--- + +## Configuration + +Configure custom Prometheus labels to add dimensions for filtering and analysis: + + + + +![Prometheus Labels](../media/ui-prometheus-labels.png) + +1. **Navigate to Configuration** + - Open Bifrost UI at `http://localhost:8080` + - Go to **Config** tab + +2. **Prometheus Labels** + ``` + Custom Labels: team, environment, organization, project + ``` + + + + +```bash +# Update prometheus labels via API +curl -X PATCH http://localhost:8080/config \ + -H "Content-Type: application/json" \ + -d '{ + "client": { + "prometheus_labels": ["team", "environment", "organization", "project"] + } + }' +``` + + + + +```json +{ + "client": { + "prometheus_labels": ["team", "environment", "organization", "project"], + "drop_excess_requests": false, + "initial_pool_size": 300 + } +} +``` + + + + +### Dynamic Label Injection + +Add custom label values at runtime using `x-bf-prom-*` headers: + +```bash +# Add custom labels to specific requests +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-bf-prom-team: engineering" \ + -H "x-bf-prom-environment: production" \ + -H "x-bf-prom-organization: my-org" \ + -H "x-bf-prom-project: my-project" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +**Header Format:** +- Prefix: `x-bf-prom-` +- Label name: Any string after the prefix +- Value: String value for the label + +--- + +## Infrastructure Setup + +### Development & Testing + +For local development and testing, use the provided Docker Compose setup: + +```bash +# Navigate to telemetry plugin directory +cd plugins/telemetry + +# Start Prometheus and Grafana +docker-compose up -d + +# Access endpoints +# Prometheus: http://localhost:9090 +# Grafana: http://localhost:3000 (admin/admin) +# Bifrost metrics: http://localhost:8080/metrics +``` + + +**Development Only**: The provided Docker Compose setup is for testing purposes only. Do not use in production without proper security, scaling, and persistence configuration. + + +You can use the Prometheus scraping endpoint to create your own Grafana dashboards. Given below are few examples created using the Docker Compose setup. + +![Grafana Dashboard](../media/ui-grafana-dashboard.png) + +### Production Deployment + +For production environments: + +1. **Deploy Prometheus** with proper persistence, retention, and security +2. **Configure scraping** to target your Bifrost instances at `/metrics` +3. **Set up Grafana** with authentication and dashboards +4. **Configure alerts** based on your SLA requirements + +**Prometheus Scrape Configuration:** +```yaml +scrape_configs: + - job_name: "bifrost-gateway" + static_configs: + - targets: ["bifrost-instance-1:8080", "bifrost-instance-2:8080"] + scrape_interval: 30s + metrics_path: /metrics +``` + +### Production Alerting Examples + +Configure alerts for critical scenarios using the new metrics: + +**High Error Rate Alert:** +```yaml +- alert: BifrostHighErrorRate + expr: sum by (provider) (rate(bifrost_error_requests_total[5m])) / sum by (provider) (rate(bifrost_upstream_requests_total[5m])) > 0.05 + for: 2m + labels: + severity: warning + annotations: + summary: "High error rate detected for provider {{ $labels.provider }} ({{ $value | humanizePercentage }})" +``` + +**High Cost Alert:** +```yaml +- alert: BifrostHighCosts + expr: sum by (provider) (increase(bifrost_cost_total[1d])) > 100 # $100/day threshold + for: 10m + labels: + severity: warning + annotations: + summary: "Daily cost for provider {{ $labels.provider }} exceeds $100 ({{ $value | printf \"%.2f\" }})" +``` + +**Cache Performance Alert:** +```yaml +- alert: BifrostLowCacheHitRate + expr: sum by (provider) (rate(bifrost_cache_hits_total[15m])) / sum by (provider) (rate(bifrost_upstream_requests_total[15m])) < 0.1 + for: 5m + labels: + severity: info + annotations: + summary: "Cache hit rate for provider {{ $labels.provider }} below 10% ({{ $value | humanizePercentage }})" +``` + +--- + +## Next Steps + +- **[Architecture Overview](../architecture/plugins/telemetry)** - Deep dive into telemetry architecture +- **[Prometheus Documentation](https://prometheus.io/docs/)** - Official Prometheus guides +- **[Grafana Setup](https://grafana.com/docs/)** - Dashboard creation and management +- **[Tracing](./tracing)** - Request/response logging for detailed analysis diff --git a/docs/features/tracing.mdx b/docs/features/tracing.mdx new file mode 100644 index 000000000..4b715f9c2 --- /dev/null +++ b/docs/features/tracing.mdx @@ -0,0 +1,239 @@ +--- +title: "Tracing" +description: "Monitor and analyze every AI request and response in real-time. Track performance, debug issues, and gain insights into your AI application's behavior with comprehensive request tracing." +icon: "paw" +--- + +## Overview + +**Tracing** is a powerful Bifrost Gateway feature that automatically captures and stores detailed information about every AI request and response that flows through your system. This includes complete request/response cycles, performance metrics, error details, and provider-specific metadata. + +Unlike basic logging, Bifrost's tracing system provides structured, searchable data with real-time monitoring capabilities, making it easy to debug issues, analyze performance patterns, and understand your AI application's behavior at scale. + +![Live Log Stream Interface](../media/ui-live-log-stream.gif) + +## How It Works + +Bifrost traces comprehensive information for every request: + +![Complete Request Tracing Overview](../media/ui-request-tracing-overview.png) + +### **Request Data** +- **Input Messages**: Complete conversation history and user prompts +- **Model Parameters**: Temperature, max tokens, tools, and all other parameters +- **Provider Context**: Which provider and model handled the request + +### **Response Data** +- **Output Messages**: AI responses, tool calls, and function results +- **Performance Metrics**: Latency and token usage +- **Status Information**: Success or error details + +### **Multimodal Support** +- **Audio Processing**: Speech synthesis and transcription inputs/outputs +- **Vision Analysis**: Image URLs and vision model responses +- **Tool Execution**: Function calling arguments and results + +![Multimodal Request Tracing](../media/ui-multimodal-tracing.png) + +All data is automatically captured without any changes to your application code. + +--- + +## Configuration + +Configure request tracing to control what gets logged and where it's stored: + + + + + +![Tracing Configuration Interface](../media/ui-tracing-config.png) + +1. Navigate to **http://localhost:8080** +2. Go to **"Settings"** +3. Toggle **"Enable Logs"** + + + + + +**Enable/Disable Tracing:** +```bash +curl --location 'http://localhost:8080/api/config' \ +--header 'Content-Type: application/json' \ +--method PUT \ +--data '{ + "enable_logging": true, + "drop_excess_requests": false, + "initial_pool_size": 300, + "enable_governance": true, + "enforce_governance_header": false, + "allow_direct_keys": false, + "prometheus_labels": [], + "allowed_origins": [] +}' +``` + +**Check Current Configuration:** +```bash +curl --location 'http://localhost:8080/api/config' +``` + +**Response includes tracing status:** +```json +{ + "client_config": { + "enable_logging": true, + "drop_excess_requests": false + }, + "is_db_connected": true, + "is_cache_connected": true, + "is_logs_connected": true +} +``` + + + + + +```json +{ + "client": { + "enable_logging": true, + "drop_excess_requests": false, + "initial_pool_size": 300, + "enable_governance": true, + "allow_direct_keys": false + }, + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + +**Configuration Options:** +- **`enable_logging`**: Master toggle for request tracing +- **`logs_store.enabled`**: Enable persistent log storage +- **`logs_store.type`**: Database type (currently `sqlite`) +- **`logs_store.config.path`**: Database file path + + + + + +--- + +## Advanced Filtering + +Retrieve and analyze logs with powerful filtering capabilities: + +![Advanced Log Filtering Interface](../media/ui-log-filtering.gif) + +### **API Filtering Options** + +```bash +curl 'http://localhost:8080/api/logs?' \ +'providers=openai,anthropic&' \ +'models=gpt-4o-mini&' \ +'status=success,error&' \ +'start_time=2024-01-15T00:00:00Z&' \ +'end_time=2024-01-15T23:59:59Z&' \ +'min_latency=1000&' \ +'max_latency=5000&' \ +'min_tokens=10&' \ +'max_tokens=1000&' \ +'min_cost=0.001&' \ +'max_cost=10&' \ +'content_search=python&' \ +'limit=100&' \ +'offset=0' +``` + +### **Available Filters** + +| Filter | Description | Example | +|--------|-------------|---------| +| `providers` | Filter by AI providers | `openai,anthropic` | +| `models` | Filter by specific models | `gpt-4o-mini,claude-3-sonnet` | +| `status` | Request status | `success,error,processing` | +| `objects` | Request types | `chat.completion,embedding` | +| `start_time` / `end_time` | Time range (RFC3339) | `2024-01-15T10:00:00Z` | +| `min_latency` / `max_latency` | Response time (ms) | `1000` to `5000` | +| `min_tokens` / `max_tokens` | Token usage range | `10` to `1000` | +| `min_cost` / `max_cost` | Cost range (USD) | `0.001` to `10` | +| `content_search` | Search in messages | `"error handling"` | +| `limit` / `offset` | Pagination | `100`, `200` | + +### **Response Format** + +```json +{ + "logs": [...], + "total_count": 1234, + "has_more": true, + "filters_applied": { + "providers": ["openai"], + "status": ["success"] + } +} +``` + +Perfect for analytics, debugging specific issues, or building custom monitoring dashboards. + +--- + +## Log Store Options + +Choose the right storage backend for your scale and requirements: + +### **Current Support** + +**SQLite** (Default) +- **Best for**: Development, small-medium deployments +- **Performance**: Excellent for read-heavy workloads +- **Setup**: Zero configuration, single file storage +- **Limits**: Single-writer, local filesystem only + +```json +{ + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + +### **Planned Support** + +**PostgreSQL** (Coming Soon) +- **Best for**: High-volume production deployments +- **Performance**: Excellent concurrent writes and complex queries +- **Features**: Advanced indexing, partitioning, replication + +**MySQL** (Coming Soon) +- **Best for**: Traditional MySQL environments +- **Performance**: Good balance of features and performance +- **Features**: Familiar ecosystem, wide tooling support + +**ClickHouse** (Coming Soon) +- **Best for**: Analytics and time-series workloads +- **Performance**: Exceptional for large-scale log analysis +- **Features**: Columnar storage, compression, real-time analytics + +{/* **To understand how the tracing plugin handles everything concurrently without increasing latency while maintaining best performance, check the [Tracing Architecture Guide](../architecture/plugins/tracing).** */} + +--- + +## Next Steps + +{/* - **[Architecture Deep Dive](../architecture/plugins/tracing)** - Internal implementation and performance tuning */} +- **[Observability](./observability)** - Metrics, monitoring, and alerting setup +- **[Gateway Setup](../quickstart/gateway/setting-up)** - Get Bifrost running with tracing enabled +- **[Provider Configuration](../quickstart/gateway/provider-configuration)** - Configure multiple providers for better insights \ No newline at end of file diff --git a/docs/features/unified-interface.mdx b/docs/features/unified-interface.mdx new file mode 100644 index 000000000..be80fbf60 --- /dev/null +++ b/docs/features/unified-interface.mdx @@ -0,0 +1,100 @@ +--- +title: "Unified Interface" +description: "Every AI provider returns the same OpenAI-compatible response format, making it seamless to switch between providers without changing your application code." +icon: "layer-group" +--- + +## One Format, All Providers + +The beauty of Bifrost lies in its unified interface: regardless of whether you're using OpenAI, Anthropic, AWS Bedrock, Google Vertex, or any other supported provider, you always get the same response format. This means your application logic never needs to change when switching providers. + +Bifrost standardizes all provider responses to follow the **OpenAI-compatible structure**, so you can write your code once and use it with any provider. + +## How It Works + +When you make a request to any provider through Bifrost, the response always follows the same structure - the familiar OpenAI format that most developers already know. Behind the scenes, Bifrost handles all the complexity of translating between different provider formats. + + + + + +```bash +# Same response format regardless of provider +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' + +# Returns OpenAI-compatible format: +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I help you?" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 9, + "total_tokens": 19 + } +} +``` + + + + + +```go +// Same response structure regardless of provider +type BifrostResponse struct { + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Choices []BifrostResponseChoice `json:"choices"` + Usage *LLMUsage `json:"usage,omitempty"` + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` +} + +// Works with any provider +response, err := client.ChatCompletion(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, // or Anthropic, Bedrock, etc. + Model: "gpt-4o-mini", // or "claude-3-sonnet", etc. + Messages: messages, +}) +// Response structure is always the same! +``` + + + + + + +## The Power of Consistency + +This unified approach means you can: + +- **Switch providers instantly** without changing application logic +- **Mix and match providers** using fallbacks and load balancing +- **Future-proof your code** as new providers get added +- **Use familiar OpenAI patterns** regardless of the underlying provider + +Whether you're calling OpenAI's GPT-4, Anthropic's Claude, or AWS Bedrock's models, your application sees the exact same response structure. This consistency is what makes Bifrost's advanced features like automatic fallbacks and multi-provider load balancing possible. + +## Provider Transparency + +While the response format stays consistent, Bifrost doesn't hide which provider actually handled your request. Provider information is always available in the `extra_fields` section, along with any provider-specific metadata you might need for debugging or analytics. + +This gives you the best of both worlds: consistent application logic with full transparency into the underlying provider behavior. + +**Learn more about configuring provider transparency:** +- **[Go SDK Provider Configuration](../quickstart/go-sdk/provider-configuration)** - Configure `SendBackRawResponse` and other provider settings +- **[Gateway Provider Configuration](../quickstart/gateway/provider-configuration)** - Configure `send_back_raw_response` via API, UI, or config file diff --git a/docs/googleTag.js b/docs/googleTag.js new file mode 100644 index 000000000..92d932b72 --- /dev/null +++ b/docs/googleTag.js @@ -0,0 +1,41 @@ +(function (w, d, s, l, i) { + w[l] = w[l] || []; w[l].push({ + 'gtm.start': + new Date().getTime(), event: 'gtm.js' + }); var f = d.getElementsByTagName(s)[0], + j = d.createElement(s), dl = l != 'dataLayer' ? '&l=' + l : ''; j.async = true; j.src = + 'https://www.googletagmanager.com/gtm.js?id=' + i + dl; f.parentNode.insertBefore(j, f); +})(window, document, 'script', 'dataLayer', 'GTM-PZVSZ6P5'); + + +(function() { + var script = document.createElement('script'); + script.src = "https://g.getmaxim.ai?id=G-Q9GWB3JQM9"; + script.async = true; + document.head.appendChild(script); +})(); + +window.dataLayer = window.dataLayer || []; +function gtag() { dataLayer.push(arguments); } +gtag('js', new Date()); +gtag('config', 'G-Q9GWB3JQM9'); + +// Attach GTM noscript to the top of the body +(function() { + var noscript = document.createElement('noscript'); + var iframe = document.createElement('iframe'); + iframe.src = "https://www.googletagmanager.com/ns.html?id=GTM-PZVSZ6P5"; + iframe.height = "0"; + iframe.width = "0"; + iframe.style.display = "none"; + iframe.style.visibility = "hidden"; + noscript.appendChild(iframe); + + if (document.body) { + document.body.insertBefore(noscript, document.body.firstChild); + } else { + document.addEventListener('DOMContentLoaded', function() { + document.body.insertBefore(noscript, document.body.firstChild); + }); + } +})(); \ No newline at end of file diff --git a/docs/integrations/anthropic-sdk.mdx b/docs/integrations/anthropic-sdk.mdx new file mode 100644 index 000000000..65a6cb7dd --- /dev/null +++ b/docs/integrations/anthropic-sdk.mdx @@ -0,0 +1,342 @@ +--- +title: "Anthropic SDK" +description: "Use Bifrost as a drop-in replacement for Anthropic API with full compatibility and enhanced features." +icon: "a" +--- + +## Overview + +Bifrost provides complete Anthropic API compatibility through protocol adaptation. The integration handles request transformation, response normalization, and error mapping between Anthropic's Messages API specification and Bifrost's internal processing pipeline. + +This integration enables you to utilize Bifrost's features like governance, load balancing, semantic caching, multi-provider support, and more, all while preserving your existing Anthropic SDK-based architecture. + +**Endpoint:** `/anthropic` + +--- + +## Setup + + + + +```python {5} +import anthropic + +# Configure client to use Bifrost +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="dummy-key" # Keys handled by Bifrost +) + +# Make requests as usual +response = client.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello!"}] +) + +print(response.content[0].text) +``` + + + + +```javascript {5} +import Anthropic from "@anthropic-ai/sdk"; + +// Configure client to use Bifrost +const anthropic = new Anthropic({ + baseURL: "http://localhost:8080/anthropic", + apiKey: "dummy-key", // Keys handled by Bifrost +}); + +// Make requests as usual +const response = await anthropic.messages.create({ + model: "claude-3-sonnet-20240229", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello!" }], +}); + +console.log(response.content[0].text); +``` + + + + +--- + +## Provider/Model Usage Examples + +Use multiple providers through the same Anthropic SDK format by prefixing model names with the provider: + + + + +```python +import anthropic + +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="dummy-key" +) + +# Anthropic models (default) +anthropic_response = client.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from Claude!"}] +) + +# OpenAI models via Anthropic SDK format +openai_response = client.messages.create( + model="openai/gpt-4o-mini", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from OpenAI!"}] +) + +# Google Vertex models via Anthropic SDK format +vertex_response = client.messages.create( + model="vertex/gemini-pro", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from Gemini!"}] +) + +# Azure OpenAI models +azure_response = client.messages.create( + model="azure/gpt-4o", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from Azure!"}] +) + +# Local Ollama models +ollama_response = client.messages.create( + model="ollama/llama3.1:8b", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from Ollama!"}] +) +``` + + + + +```javascript +import Anthropic from "@anthropic-ai/sdk"; + +const anthropic = new Anthropic({ + baseURL: "http://localhost:8080/anthropic", + apiKey: "dummy-key", +}); + +// Anthropic models (default) +const anthropicResponse = await anthropic.messages.create({ + model: "claude-3-sonnet-20240229", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from Claude!" }], +}); + +// OpenAI models via Anthropic SDK format +const openaiResponse = await anthropic.messages.create({ + model: "openai/gpt-4o-mini", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from OpenAI!" }], +}); + +// Google Vertex models via Anthropic SDK format +const vertexResponse = await anthropic.messages.create({ + model: "vertex/gemini-pro", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from Gemini!" }], +}); + +// Azure OpenAI models +const azureResponse = await anthropic.messages.create({ + model: "azure/gpt-4o", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from Azure!" }], +}); + +// Local Ollama models +const ollamaResponse = await anthropic.messages.create({ + model: "ollama/llama3.1:8b", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from Ollama!" }], +}); +``` + + + + +--- + +## Adding Custom Headers + +Pass custom headers required by Bifrost plugins (like governance, telemetry, etc.): + + + + +```python +import anthropic + +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="dummy-key", + default_headers={ + "x-bf-vk": "vk_12345", # Virtual key for governance + "x-bf-user-id": "user_789", # User identification + "x-bf-team-id": "team_456", # Team identification + "x-bf-trace-id": "trace_abc123", # Request tracing + } +) + +response = client.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello with custom headers!"}] +) +``` + + + + +```javascript +import Anthropic from "@anthropic-ai/sdk"; + +const anthropic = new Anthropic({ + baseURL: "http://localhost:8080/anthropic", + apiKey: "dummy-key", + defaultHeaders: { + "x-bf-vk": "vk_12345", // Virtual key for governance + "x-bf-user-id": "user_789", // User identification + "x-bf-team-id": "team_456", // Team identification + "x-bf-trace-id": "trace_abc123", // Request tracing + }, +}); + +const response = await anthropic.messages.create({ + model: "claude-3-sonnet-20240229", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello with custom headers!" }], +}); +``` + + + + +--- + +## Using Direct Keys + +Pass API keys directly in requests to bypass Bifrost's load balancing. You can pass any provider's API key (OpenAI, Anthropic, Mistral, etc.) since Bifrost only looks for `Authorization` or `x-api-key` headers. This requires the **Allow Direct API keys** option to be enabled in Bifrost configuration. + +> **Learn more:** See [Quickstart Configuration](../quickstart/README) for enabling direct API key usage. + + + + +```python +import anthropic + +# Using Anthropic's API key directly +client_with_direct_key = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="sk-your-anthropic-key" # Anthropic's API key works +) + +anthropic_response = client_with_direct_key.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello from Claude!"}] +) + +# or pass different provider keys per request using headers +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="dummy-key" +) + +# Use Anthropic key for Claude +anthropic_response = client.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello Claude!"}], + extra_headers={ + "x-api-key": "sk-ant-your-anthropic-key" + } +) + +# Use OpenAI key for GPT models +openai_response = client.messages.create( + model="openai/gpt-4o-mini", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello GPT!"}], + extra_headers={ + "Authorization": "Bearer sk-your-openai-key" + } +) +``` + + + + +```javascript +import Anthropic from "@anthropic-ai/sdk"; + +// Using Anthropic's API key directly +const anthropicWithDirectKey = new Anthropic({ + baseURL: "http://localhost:8080/anthropic", + apiKey: "sk-your-anthropic-key", // Anthropic's API key works +}); + + +const anthropicResponse = await anthropicWithDirectKey.messages.create({ + model: "claude-3-sonnet-20240229", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello from Claude!" }], +}); + + +// or pass different provider keys per request using headers +const anthropic = new Anthropic({ + baseURL: "http://localhost:8080/anthropic", + apiKey: "dummy-key", +}); + +// Use Anthropic key for Claude +const anthropicResponse = await anthropic.messages.create({ + model: "claude-3-sonnet-20240229", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello Claude!" }], + headers: { + "x-api-key": "sk-ant-your-anthropic-key", + }, +}); + +// Use OpenAI key for GPT models +const openaiResponseWithHeader = await anthropic.messages.create({ + model: "openai/gpt-4o-mini", + max_tokens: 1000, + messages: [{ role: "user", content: "Hello GPT!" }], + headers: { + "Authorization": "Bearer sk-your-openai-key", + }, +}); +``` + + + + +--- + +## Supported Features + +The Anthropic integration supports all features that are available in both the Anthropic SDK and Bifrost core functionality. If the Anthropic SDK supports a feature and Bifrost supports it, the integration will work seamlessly. πŸ˜„ + +--- + +## Next Steps + +- **[OpenAI SDK](./openai-sdk)** - GPT integration patterns +- **[Google GenAI SDK](./genai-sdk)** - Gemini integration patterns +- **[Configuration](../quickstart/README)** - Bifrost setup and configuration +- **[Core Features](../features/)** - Advanced Bifrost capabilities diff --git a/docs/integrations/genai-sdk.mdx b/docs/integrations/genai-sdk.mdx new file mode 100644 index 000000000..42b893951 --- /dev/null +++ b/docs/integrations/genai-sdk.mdx @@ -0,0 +1,288 @@ +--- +title: "Google GenAI SDK" +description: "Use Bifrost as a drop-in replacement for Google GenAI API with full compatibility and enhanced features." +icon: "g" +--- + +## Overview + +Bifrost provides complete Google GenAI API compatibility through protocol adaptation. The integration handles request transformation, response normalization, and error mapping between Google's GenAI API specification and Bifrost's internal processing pipeline. + +This integration enables you to utilize Bifrost's features like governance, load balancing, semantic caching, multi-provider support, and more, all while preserving your existing Google GenAI SDK-based architecture. + +**Endpoint:** `/genai` + +--- + +## Setup + + + + +```python {7} +from google import genai +from google.genai.types import HttpOptions + +# Configure client to use Bifrost +client = genai.Client( + api_key="dummy-key", # Keys handled by Bifrost + http_options=HttpOptions(base_url="http://localhost:8080/genai") +) + +# Make requests as usual +response = client.models.generate_content( + model="gemini-1.5-flash", + contents="Hello!" +) + +print(response.text) +``` + + + + +```javascript {5} +import { GoogleGenerativeAI } from "@google/generative-ai"; + +// Configure client to use Bifrost +const genAI = new GoogleGenerativeAI("dummy-key", { + baseUrl: "http://localhost:8080/genai", // Keys handled by Bifrost +}); + +// Make requests as usual +const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); +const response = await model.generateContent("Hello!"); + +console.log(response.response.text()); +``` + + + + +--- + +## Provider/Model Usage Examples + +Use multiple providers through the same GenAI SDK format by prefixing model names with the provider: + + + + +```python +from google import genai +from google.genai.types import HttpOptions + +client = genai.Client( + api_key="dummy-key", + http_options=HttpOptions(base_url="http://localhost:8080/genai") +) + +# Google Vertex models (default) +vertex_response = client.models.generate_content( + model="gemini-1.5-flash", + contents="Hello from Gemini!" +) + +# OpenAI models via GenAI SDK format +openai_response = client.models.generate_content( + model="openai/gpt-4o-mini", + contents="Hello from OpenAI!" +) + +# Anthropic models via GenAI SDK format +anthropic_response = client.models.generate_content( + model="anthropic/claude-3-sonnet-20240229", + contents="Hello from Claude!" +) + +# Azure OpenAI models +azure_response = client.models.generate_content( + model="azure/gpt-4o", + contents="Hello from Azure!" +) + +# Local Ollama models +ollama_response = client.models.generate_content( + model="ollama/llama3.1:8b", + contents="Hello from Ollama!" +) +``` + + + + +```javascript +import { GoogleGenerativeAI } from "@google/generative-ai"; + +const genAI = new GoogleGenerativeAI("dummy-key", { + baseUrl: "http://localhost:8080/genai", +}); + +// Google Vertex models (default) +const geminiModel = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); +const vertexResponse = await geminiModel.generateContent("Hello from Gemini!"); + +// OpenAI models via GenAI SDK format +const openaiModel = genAI.getGenerativeModel({ model: "openai/gpt-4o-mini" }); +const openaiResponse = await openaiModel.generateContent("Hello from OpenAI!"); + +// Anthropic models via GenAI SDK format +const anthropicModel = genAI.getGenerativeModel({ model: "anthropic/claude-3-sonnet-20240229" }); +const anthropicResponse = await anthropicModel.generateContent("Hello from Claude!"); + +// Azure OpenAI models +const azureModel = genAI.getGenerativeModel({ model: "azure/gpt-4o" }); +const azureResponse = await azureModel.generateContent("Hello from Azure!"); + +// Local Ollama models +const ollamaModel = genAI.getGenerativeModel({ model: "ollama/llama3.1:8b" }); +const ollamaResponse = await ollamaModel.generateContent("Hello from Ollama!"); +``` + + + + +--- + +## Adding Custom Headers + +Pass custom headers required by Bifrost plugins (like governance, telemetry, etc.): + + + + +```python +from google import genai +from google.genai.types import HttpOptions + +# Configure client with custom headers +client = genai.Client( + api_key="dummy-key", + http_options=HttpOptions( + base_url="http://localhost:8080/genai", + headers={ + "x-bf-vk": "vk_12345", # Virtual key for governance + "x-bf-user-id": "user_789", # User identification + "x-bf-team-id": "team_456", # Team identification + "x-bf-trace-id": "trace_abc123", # Request tracing + } + ) +) + +response = client.models.generate_content( + model="gemini-1.5-flash", + contents="Hello with custom headers!" +) +``` + + + + +```javascript +import { GoogleGenerativeAI } from "@google/generative-ai"; + +// Configure client with custom headers +const genAI = new GoogleGenerativeAI("dummy-key", { + baseUrl: "http://localhost:8080/genai", + customHeaders: { + "x-bf-vk": "vk_12345", // Virtual key for governance + "x-bf-user-id": "user_789", // User identification + "x-bf-team-id": "team_456", // Team identification + "x-bf-trace-id": "trace_abc123", // Request tracing + }, +}); + +const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); +const response = await model.generateContent("Hello with custom headers!"); +``` + + + + +--- + +## Using Direct Keys + +Pass API keys directly in requests to bypass Bifrost's load balancing. You can pass any provider's API key (OpenAI, Anthropic, Mistral, etc.) since Bifrost only looks for `Authorization` or `x-api-key` headers. This requires the **Allow Direct API keys** option to be enabled in Bifrost configuration. + +> **Learn more:** See [Quickstart Configuration](../quickstart/README) for enabling direct API key usage. + + + + +```python +from google import genai +from google.genai.types import HttpOptions + +# Pass different provider keys per request using headers +client = genai.Client( + api_key="dummy-key", + http_options=HttpOptions(base_url="http://localhost:8080/genai") +) + +# Use Anthropic key for Claude models +anthropic_response = client.models.generate_content( + model="anthropic/claude-3-sonnet-20240229", + contents="Hello Claude!", + request_options={ + "headers": {"x-api-key": "your-anthropic-api-key"} + } +) + +# Use OpenAI key for GPT models +openai_response = client.models.generate_content( + model="openai/gpt-4o-mini", + contents="Hello GPT!", + request_options={ + "headers": {"Authorization": "Bearer sk-your-openai-key"} + } +) +``` + + + + +```javascript +import { GoogleGenerativeAI } from "@google/generative-ai"; + +// Pass different provider keys per request using headers +const genAI = new GoogleGenerativeAI("dummy-key", { + baseUrl: "http://localhost:8080/genai", +}); + +// Use Anthropic key for Claude models +const anthropicModel = genAI.getGenerativeModel({ + model: "anthropic/claude-3-sonnet-20240229", + requestOptions: { + customHeaders: { "x-api-key": "your-anthropic-api-key" } + } +}); +const anthropicResponse = await anthropicModel.generateContent("Hello Claude!"); + +// Use OpenAI key for GPT models +const gptModel = genAI.getGenerativeModel({ + model: "openai/gpt-4o-mini", + requestOptions: { + customHeaders: { "Authorization": "Bearer sk-your-openai-key" } + } +}); +const gptResponse = await gptModel.generateContent("Hello GPT!"); +``` + + + + +--- + +## Supported Features + +The Google GenAI integration supports all features that are available in both the Google GenAI SDK and Bifrost core functionality. If the Google GenAI SDK supports a feature and Bifrost supports it, the integration will work seamlessly. πŸ˜„ + +--- + +## Next Steps + +- **[OpenAI SDK](./openai-sdk)** - GPT integration patterns +- **[Anthropic SDK](./anthropic-sdk)** - Claude integration patterns +- **[Configuration](../quickstart/README)** - Bifrost setup and configuration +- **[Core Features](../features/)** - Advanced Bifrost capabilities diff --git a/docs/integrations/langchain-sdk.mdx b/docs/integrations/langchain-sdk.mdx new file mode 100644 index 000000000..ea66b0000 --- /dev/null +++ b/docs/integrations/langchain-sdk.mdx @@ -0,0 +1,311 @@ +--- +title: "Langchain SDK" +description: "Use Bifrost as a drop-in proxy for Langchain applications with zero code changes." +icon: "crow" +--- + +Since Langchain already provides multi-provider abstraction and chaining capabilities, Bifrost adds enterprise features like governance, semantic caching, MCP tools, observability, etc, on top of your existing setup. + +**Endpoint:** `/langchain` + + +**Provider Compatibility:** This integration only works for AI providers that both Langchain and Bifrost support. If you're using a provider specific to Langchain that Bifrost doesn't support (or vice versa), those requests will fail. + +--- + +## Setup + + + + +```python {7} +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage + +# Configure client to use Bifrost +llm = ChatOpenAI( + model="gpt-4o-mini", + openai_api_base="http://localhost:8080/langchain", # Point to Bifrost + openai_api_key="dummy-key" # Keys managed by Bifrost +) + +response = llm.invoke([HumanMessage(content="Hello!")]) +print(response.content) +``` + + + + +```javascript {7} +import { ChatOpenAI } from "@langchain/openai"; + +// Configure client to use Bifrost +const llm = new ChatOpenAI({ + model: "gpt-4o-mini", + configuration: { + baseURL: "http://localhost:8080/langchain", // Point to Bifrost + }, + openAIApiKey: "dummy-key" // Keys managed by Bifrost +}); + +const response = await llm.invoke("Hello!"); +console.log(response.content); +``` + + + + + +--- + +## Provider/Model Usage Examples + +Your existing Langchain provider switching works unchanged through Bifrost: + + + + +```python +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_core.messages import HumanMessage + +base_url = "http://localhost:8080/langchain" + +# OpenAI models via Langchain +openai_llm = ChatOpenAI( + model="gpt-4o-mini", + openai_api_base=base_url +) + +# Anthropic models via Langchain +anthropic_llm = ChatAnthropic( + model="claude-3-sonnet-20240229", + anthropic_api_url=base_url +) + +# Google models via Langchain +google_llm = ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_base=base_url +) + +# All work the same way +openai_response = openai_llm.invoke([HumanMessage(content="Hello GPT!")]) +anthropic_response = anthropic_llm.invoke([HumanMessage(content="Hello Claude!")]) +google_response = google_llm.invoke([HumanMessage(content="Hello Gemini!")]) +``` + + + + +```javascript +import { ChatOpenAI } from "@langchain/openai"; +import { ChatAnthropic } from "@langchain/anthropic"; +import { ChatGoogleGenerativeAI } from "@langchain/google-genai"; + +const baseURL = "http://localhost:8080/langchain"; + +// OpenAI models via Langchain +const openaiLlm = new ChatOpenAI({ + model: "gpt-4o-mini", + configuration: { baseURL } +}); + +// Anthropic models via Langchain +const anthropicLlm = new ChatAnthropic({ + model: "claude-3-sonnet-20240229", + clientOptions: { baseURL } +}); + +// Google models via Langchain +const googleLlm = new ChatGoogleGenerativeAI({ + model: "gemini-1.5-flash", + baseURL +}); + +// All work the same way +const openaiResponse = await openaiLlm.invoke("Hello GPT!"); +const anthropicResponse = await anthropicLlm.invoke("Hello Claude!"); +const googleResponse = await googleLlm.invoke("Hello Gemini!"); +``` + + + + +--- + +## Adding Custom Headers + +Add Bifrost-specific headers for governance and tracking: + + + + +```python +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage + +# Add custom headers for Bifrost features +llm = ChatOpenAI( + model="gpt-4o-mini", + openai_api_base="http://localhost:8080/langchain", + default_headers={ + "x-bf-vk": "your-virtual-key", # Virtual key for governance + "x-bf-user-id": "user123", # User tracking + "x-bf-team-id": "team-ai", # Team tracking + "x-bf-trace-id": "trace-456" # Custom trace ID + } +) + +response = llm.invoke([HumanMessage(content="Hello!")]) +print(response.content) +``` + + + + +```javascript +import { ChatOpenAI } from "@langchain/openai"; + +// Add custom headers for Bifrost features +const llm = new ChatOpenAI({ + model: "gpt-4o-mini", + configuration: { + baseURL: "http://localhost:8080/langchain", + defaultHeaders: { + "x-bf-vk": "your-virtual-key", // Virtual key for governance + "x-bf-user-id": "user123", // User tracking + "x-bf-team-id": "team-ai", // Team tracking + "x-bf-trace-id": "trace-456" // Custom trace ID + } + } +}); + +const response = await llm.invoke("Hello!"); +console.log(response.content); +``` + + + + +--- + +## Using Direct Keys + +Pass API keys directly to bypass Bifrost's key management. You can pass any provider's API key since Bifrost only looks for `Authorization` or `x-api-key` headers. This requires the **Allow Direct API keys** option to be enabled in Bifrost configuration. + +> **Learn more:** See [Quickstart Configuration](../quickstart/README) for enabling direct API key usage. + + + + +```python +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import HumanMessage + +# Using OpenAI key directly +openai_llm = ChatOpenAI( + model="gpt-4o-mini", + openai_api_base="http://localhost:8080/langchain", + default_headers={ + "Authorization": "Bearer sk-your-openai-key" + } +) + +# Using Anthropic key for Claude models +anthropic_llm = ChatAnthropic( + model="claude-3-sonnet-20240229", + anthropic_api_url="http://localhost:8080/langchain", + default_headers={ + "x-api-key": "sk-ant-your-anthropic-key" + } +) + +# Using Azure OpenAI with direct Azure key +from langchain_openai import AzureChatOpenAI + +azure_llm = AzureChatOpenAI( + deployment_name="gpt-4o-aug", + api_key="your-azure-api-key", + azure_endpoint="http://localhost:8080/langchain", + api_version="2024-05-01-preview", + max_tokens=100, + default_headers={ + "x-bf-azure-endpoint": "https://your-resource.openai.azure.com", + } +) + +openai_response = openai_llm.invoke([HumanMessage(content="Hello GPT!")]) +anthropic_response = anthropic_llm.invoke([HumanMessage(content="Hello Claude!")]) +azure_response = azure_llm.invoke([HumanMessage(content="Hello from Azure!")]) +``` + + + + +```javascript +import { ChatOpenAI } from "@langchain/openai"; +import { ChatAnthropic } from "@langchain/anthropic"; + +// Using OpenAI key directly +const openaiLlm = new ChatOpenAI({ + model: "gpt-4o-mini", + configuration: { + baseURL: "http://localhost:8080/langchain", + defaultHeaders: { + "Authorization": "Bearer sk-your-openai-key" + } + } +}); + +// Using Anthropic key for Claude models +const anthropicLlm = new ChatAnthropic({ + model: "claude-3-sonnet-20240229", + clientOptions: { + baseURL: "http://localhost:8080/langchain", + defaultHeaders: { + "x-api-key": "sk-ant-your-anthropic-key" + } + } +}); + +// Using Azure OpenAI with direct Azure key +import { AzureChatOpenAI } from "@langchain/openai"; + +const azureLlm = new AzureChatOpenAI({ + deploymentName: "gpt-4o-aug", + apiKey: "your-azure-api-key", + azureOpenAIEndpoint: "http://localhost:8080/langchain", + apiVersion: "2024-05-01-preview", + maxTokens: 100, + configuration: { + defaultHeaders: { + "x-bf-azure-endpoint": "https://your-resource.openai.azure.com", + } + } +}); + +const openaiResponse = await openaiLlm.invoke("Hello GPT!"); +const anthropicResponse = await anthropicLlm.invoke("Hello Claude!"); +const azureResponse = await azureLlm.invoke("Hello from Azure!"); +``` + + + + +--- + +## Supported Features + +The Langchain integration supports all features that are available in both the Langchain SDK and Bifrost core functionality. Your existing Langchain chains and workflows work seamlessly with Bifrost's enterprise features. πŸ˜„ + +--- + +## Next Steps + +- **[Governance Features](../features/governance)** - Virtual keys and team management +- **[Semantic Caching](../features/semantic-caching)** - Intelligent response caching +- **[Configuration](../quickstart/README)** - Provider setup and API key management diff --git a/docs/integrations/litellm-sdk.mdx b/docs/integrations/litellm-sdk.mdx new file mode 100644 index 000000000..3d6d1cb2d --- /dev/null +++ b/docs/integrations/litellm-sdk.mdx @@ -0,0 +1,183 @@ +--- +title: "LiteLLM SDK" +description: "Use Bifrost as a drop-in proxy for LiteLLM applications with zero code changes." +icon: "train" +--- + +Since LiteLLM already provides multi-provider abstraction, Bifrost adds enterprise features like governance, semantic caching, MCP tools, observability, etc, on top of your existing setup. + +**Endpoint:** `/litellm` + + + **Provider Compatibility:** This integration only works for AI providers that both LiteLLM and Bifrost support. If you're using a provider specific to LiteLLM that Bifrost doesn't support (or vice versa), those requests will fail. + +--- + +## Setup + + + + +```python {7} +from litellm import completion + +# Configure client to use Bifrost +response = completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello!"}], + base_url="http://localhost:8080/litellm" # Point to Bifrost +) + +print(response.choices[0].message.content) +``` + + + + +--- + +## Provider/Model Usage Examples + +Your existing LiteLLM provider switching works unchanged through Bifrost: + + + + +```python {4} +from litellm import completion + +# All your existing LiteLLM patterns work the same +base_url = "http://localhost:8080/litellm" + +# OpenAI models +openai_response = completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello GPT!"}], + base_url=base_url +) + +# Anthropic models +anthropic_response = completion( + model="claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello Claude!"}], + base_url=base_url +) + +# Google models +google_response = completion( + model="gemini/gemini-1.5-flash", + messages=[{"role": "user", "content": "Hello Gemini!"}], + base_url=base_url +) + +# Azure OpenAI models +azure_response = completion( + model="azure/gpt-4o", + messages=[{"role": "user", "content": "Hello Azure!"}], + base_url=base_url +) +``` + + + + +--- + +## Adding Custom Headers + +Add Bifrost-specific headers for governance and tracking: + + + + +```python +from litellm import completion + +# Add custom headers for Bifrost features +response = completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello!"}], + base_url="http://localhost:8080/litellm", + extra_headers={ + "x-bf-vk": "your-virtual-key", # Virtual key for governance + "x-bf-user-id": "user123", # User tracking + "x-bf-team-id": "team-ai", # Team tracking + "x-bf-trace-id": "trace-456" # Custom trace ID + } +) + +print(response.choices[0].message.content) +``` + + + + +--- + +## Using Direct Keys + +Pass API keys directly to bypass Bifrost's key management. You can pass any provider's API key since Bifrost only looks for `Authorization` or `x-api-key` headers. This requires the **Allow Direct API keys** option to be enabled in Bifrost configuration. + +> **Learn more:** See [Quickstart Configuration](../quickstart/README) for enabling direct API key usage. + + + + +```python +from litellm import completion + +# Using OpenAI key directly +openai_response = completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello GPT!"}], + base_url="http://localhost:8080/litellm", + extra_headers={ + "Authorization": "Bearer sk-your-openai-key" + } +) + +# Using Anthropic key for Claude models +anthropic_response = completion( + model="claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello Claude!"}], + base_url="http://localhost:8080/litellm", + extra_headers={ + "x-api-key": "sk-ant-your-anthropic-key" + } +) + +# Using Azure OpenAI with direct Azure key +import os + +deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT", "my-azure-deployment") +model = f"azure/{deployment}" + +azure_response = completion( + model=model, + messages=[{"role": "user", "content": "Hello from LiteLLM (Azure demo)!"}], + base_url="http://localhost:8080/litellm", + api_key=os.getenv("AZURE_API_KEY", "your-azure-api-key"), + deployment_id=os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-aug"), + max_tokens=100, + extra_headers={ + "x-bf-azure-endpoint": "https://your-resource.openai.azure.com", + } +) +``` + + + + +--- + +## Supported Features + +The LiteLLM integration supports all features that are available in both the LiteLLM SDK and Bifrost core functionality. Your existing LiteLLM code works seamlessly with Bifrost's enterprise features. πŸ˜„ + +--- + +## Next Steps + +- **[Governance Features](../features/governance)** - Virtual keys and team management +- **[Semantic Caching](../features/semantic-caching)** - Intelligent response caching +- **[Configuration](../quickstart/README)** - Provider setup and API key management diff --git a/docs/integrations/openai-sdk.mdx b/docs/integrations/openai-sdk.mdx new file mode 100644 index 000000000..889bbdd63 --- /dev/null +++ b/docs/integrations/openai-sdk.mdx @@ -0,0 +1,371 @@ +--- +title: "OpenAI SDK" +description: "Use Bifrost as a drop-in replacement for OpenAI API with full compatibility and enhanced features." +icon: "o" +--- + +## Overview + +Bifrost provides complete OpenAI API compatibility through protocol adaptation. The integration handles request transformation, response normalization, and error mapping between OpenAI's API specification and Bifrost's internal processing pipeline. + +This integration enables you to utilize Bifrost's features like governance, load balancing, semantic caching, multi-provider support, and more, all while preserving your existing OpenAI SDK-based architecture. + +**Endpoint:** `/openai` + +--- + +## Setup + + + + +```python {5} +import openai + +# Configure client to use Bifrost +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy-key" # Keys handled by Bifrost +) + +# Make requests as usual +response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello!"}] +) + +print(response.choices[0].message.content) +``` + + + + +```javascript {5} +import OpenAI from "openai"; + +// Configure client to use Bifrost +const openai = new OpenAI({ + baseURL: "http://localhost:8080/openai", + apiKey: "dummy-key", // Keys handled by Bifrost +}); + +// Make requests as usual +const response = await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "Hello!" }], +}); + +console.log(response.choices[0].message.content); +``` + + + + +--- + +## Provider/Model Usage Examples + +Use multiple providers through the same OpenAI SDK format by prefixing model names with the provider: + + + + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy-key" +) + +# OpenAI models (default) +openai_response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello from OpenAI!"}] +) + +# Anthropic models via OpenAI SDK format +anthropic_response = client.chat.completions.create( + model="anthropic/claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello from Claude!"}] +) + +# Google Vertex models via OpenAI SDK format +vertex_response = client.chat.completions.create( + model="vertex/gemini-pro", + messages=[{"role": "user", "content": "Hello from Gemini!"}] +) + +# Azure OpenAI models +azure_response = client.chat.completions.create( + model="azure/gpt-4o", + messages=[{"role": "user", "content": "Hello from Azure!"}] +) + +# Local Ollama models +ollama_response = client.chat.completions.create( + model="ollama/llama3.1:8b", + messages=[{"role": "user", "content": "Hello from Ollama!"}] +) +``` + + + + +```javascript +import OpenAI from "openai"; + +const openai = new OpenAI({ + baseURL: "http://localhost:8080/openai", + apiKey: "dummy-key", +}); + +// OpenAI models (default) +const openaiResponse = await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "Hello from OpenAI!" }], +}); + +// Anthropic models via OpenAI SDK format +const anthropicResponse = await openai.chat.completions.create({ + model: "anthropic/claude-3-sonnet-20240229", + messages: [{ role: "user", content: "Hello from Claude!" }], +}); + +// Google Vertex models via OpenAI SDK format +const vertexResponse = await openai.chat.completions.create({ + model: "vertex/gemini-pro", + messages: [{ role: "user", content: "Hello from Gemini!" }], +}); + +// Azure OpenAI models +const azureResponse = await openai.chat.completions.create({ + model: "azure/gpt-4o", + messages: [{ role: "user", content: "Hello from Azure!" }], +}); + +// Local Ollama models +const ollamaResponse = await openai.chat.completions.create({ + model: "ollama/llama3.1:8b", + messages: [{ role: "user", content: "Hello from Ollama!" }], +}); +``` + + + + +--- + +## Adding Custom Headers + +Pass custom headers required by Bifrost plugins (like governance, telemetry, etc.): + + + + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy-key", + default_headers={ + "x-bf-vk": "vk_12345", # Virtual key for governance + "x-bf-user-id": "user_789", # User identification + "x-bf-team-id": "team_456", # Team identification + "x-bf-trace-id": "trace_abc123", # Request tracing + } +) + +response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello with custom headers!"}] +) +``` + + + + +```javascript +import OpenAI from "openai"; + +const openai = new OpenAI({ + baseURL: "http://localhost:8080/openai", + apiKey: "dummy-key", + defaultHeaders: { + "x-bf-vk": "vk_12345", // Virtual key for governance + "x-bf-user-id": "user_789", // User identification + "x-bf-team-id": "team_456", // Team identification + "x-bf-trace-id": "trace_abc123", // Request tracing + }, +}); + +const response = await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "Hello with custom headers!" }], +}); +``` + + + + +--- + +## Using Direct Keys + +Pass API keys directly in requests to bypass Bifrost's load balancing. You can pass any provider's API key (OpenAI, Anthropic, Mistral, etc.) since Bifrost only looks for `Authorization` or `x-api-key` headers. This requires the **Allow Direct API keys** option to be enabled in Bifrost configuration. + +> **Learn more:** See [Quickstart Configuration](../quickstart/README) for enabling direct API key usage. + + + + +```python +import openai + +# Using OpenAI's API key directly +client_with_direct_key = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="sk-your-openai-key" # OpenAI's API key works +) + +openai_response = client_with_direct_key.chat.completions.create( + model="openai/gpt-4o-mini", + messages=[{"role": "user", "content": "Hello from GPT!"}] +) + +# Or pass different provider keys per request +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy-key" +) + +# Use OpenAI key for GPT models +openai_response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello GPT!"}], + extra_headers={ + "Authorization": "Bearer sk-your-openai-key" + } +) + +# Use Anthropic key for Claude models +anthropic_response = client.chat.completions.create( + model="anthropic/claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello Claude!"}], + extra_headers={ + "x-api-key": "sk-ant-your-anthropic-key" + } +) +``` + + + + +```javascript +import OpenAI from "openai"; + +// Using OpenAI's API key directly +const openaiWithDirectKey = new OpenAI({ + baseURL: "http://localhost:8080/openai", + apiKey: "sk-your-openai-key", // OpenAI's API key works +}); + +const openaiResponse = await openaiWithDirectKey.chat.completions.create({ + model: "openai/gpt-4o-mini", + messages: [{ role: "user", content: "Hello from GPT!" }], +}); + +// Or pass different provider keys per request +const openai = new OpenAI({ + baseURL: "http://localhost:8080/openai", + apiKey: "dummy-key", +}); + +// Use OpenAI key for GPT models +const openaiResponse = await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "Hello GPT!" }], + headers: { + "Authorization": "Bearer sk-your-openai-key", + }, +}); + +// Use Anthropic key for Claude models +const anthropicResponseWithHeader = await openai.chat.completions.create({ + model: "anthropic/claude-3-sonnet-20240229", + messages: [{ role: "user", content: "Hello Claude!" }], + headers: { + "x-api-key": "sk-ant-your-anthropic-key", + }, +}); +``` + + + + +For Azure OpenAI, you can use the AzureOpenAI client and point it to Bifrost integration endpoint. The `x-bf-azure-endpoint` header is required to specify your Azure OpenAI resource endpoint. + + + + +```python +from openai import AzureOpenAI + +azure_client = AzureOpenAI( + api_key="your-azure-api-key", + api_version="2024-02-01", + azure_endpoint="http://localhost:8080/openai", # Point to Bifrost + default_headers={ + "x-bf-azure-endpoint": "https://your-resource.openai.azure.com" + } +) + +azure_response = azure_client.chat.completions.create( + model="gpt-4-deployment", # Your deployment name + messages=[{"role": "user", "content": "Hello from Azure!"}] +) + +print(azure_response.choices[0].message.content) +``` + + + + +```javascript +import { AzureOpenAI } from "openai"; + +const azureClient = new AzureOpenAI({ + apiKey: "your-azure-api-key", + apiVersion: "2024-02-01", + baseURL: "http://localhost:8080/openai", // Point to Bifrost + defaultHeaders: { + "x-bf-azure-endpoint": "https://your-resource.openai.azure.com" + } +}); + +const azureResponse = await azureClient.chat.completions.create({ + model: "gpt-4-deployment", // Your deployment name + messages: [{ role: "user", content: "Hello from Azure!" }], +}); + +console.log(azureResponse.choices[0].message.content); +``` + + + + +--- + +## Supported Features + +The OpenAI integration supports all features that are available in both the OpenAI SDK and Bifrost core functionality. If the OpenAI SDK supports a feature and Bifrost supports it, the integration will work seamlessly. πŸ˜„ + +--- + +## Next Steps + +- **[Anthropic SDK](./anthropic-sdk)** - Claude integration patterns +- **[Google GenAI SDK](./genai-sdk)** - Gemini integration patterns +- **[Configuration](../quickstart/README)** - Bifrost setup and configuration +- **[Core Features](../features/)** - Advanced Bifrost capabilities \ No newline at end of file diff --git a/docs/integrations/what-is-an-integration.mdx b/docs/integrations/what-is-an-integration.mdx new file mode 100644 index 000000000..bb4c87306 --- /dev/null +++ b/docs/integrations/what-is-an-integration.mdx @@ -0,0 +1,231 @@ +--- +title: "What is an integration?" +description: "Protocol adapters that translate between Bifrost's unified API and provider-specific API formats like OpenAI, Anthropic, and Google GenAI." +icon: "box" +--- + +## Overview + +An integration is a protocol adapter that translates between Bifrost's unified API and provider-specific API formats. Each integration handles request transformation, response normalization, and error mapping between the external API contract and Bifrost's internal processing pipeline. + +Integrations enable you to utilize Bifrost's features like governance, MCP tools, load balancing, semantic caching, multi-provider support, and more, all while preserving your existing SDK-based architecture. Bifrost handles all the overhead of structure conversion, requiring only a single URL change to switch from direct provider APIs to Bifrost's gateway. + +Bifrost converts the request/response format of the provider API to the Bifrost API format based on the integration used, so you don't have to. + +--- + +## Quick Migration + +### **Before (Direct Provider)** + +```python +import openai + +client = openai.OpenAI( + api_key="your-openai-key" +) +``` + +### **After (Bifrost)** + +```python {4} +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/openai", # Point to Bifrost + api_key="dummy-key" # Keys are handled in Bifrost now +) +``` + +**That's it!** Your application now benefits from Bifrost's features with no other changes. + +--- + +## Supported Integrations + +1. [OpenAI](./openai-sdk) +2. [Anthropic](./anthropic-sdk) +3. [Google GenAI](./genai-sdk) +4. [LiteLLM](./litellm-sdk) +5. [Langchain](./langchain-sdk) + +--- + +## Provider-Prefixed Models + +Use multiple providers seamlessly by prefixing model names with the provider: + + + +```python +import openai + +# Single client, multiple providers +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy" # API keys configured in Bifrost +) + +# OpenAI models +response1 = client.chat.completions.create( + model="gpt-4o-mini", # (default OpenAI since it's OpenAI's SDK) + messages=[{"role": "user", "content": "Hello!"}] +) +``` + + +```python +import openai + +# Anthropic models using OpenAI SDK format +response2 = client.chat.completions.create( + model="anthropic/claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello!"}] +) +``` + + +```python +import openai + +# Azure OpenAI models +response4 = client.chat.completions.create( + model="azure/gpt-4o", + messages=[{"role": "user", "content": "Hello!"}] +) +``` + + +```python +import openai + +# Google Vertex models +response3 = client.chat.completions.create( + model="vertex/gemini-pro", + messages=[{"role": "user", "content": "Hello!"}] +) +``` + + +```python +import openai + +# Local Ollama models +response5 = client.chat.completions.create( + model="ollama/llama3.1:8b", + messages=[{"role": "user", "content": "Hello!"}] +) +``` + + + +--- + +## Direct API Usage + +For custom HTTP clients or when you have existing provider-specific setup and want to use Bifrost gateway without restructuring your codebase: + +```python {5,18,31,} +import requests + +# Fully OpenAI compatible endpoint +response = requests.post( + "http://localhost:8080/openai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {openai_key}", + "Content-Type": "application/json" + }, + json={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + } +) + +# Fully Anthropic compatible endpoint +response = requests.post( + "http://localhost:8080/anthropic/v1/messages", + headers={ + "Content-Type": "application/json", + }, + json={ + "model": "claude-3-sonnet-20240229", + "max_tokens": 1000, + "messages": [{"role": "user", "content": "Hello!"}] + } +) + +# Fully Google GenAI compatible endpoint +response = requests.post( + "http://localhost:8080/genai/v1beta/models/gemini-1.5-flash/generateContent", + headers={ + "Content-Type": "application/json", + }, + json={ + "contents": [ + {"parts": [{"text": "Hello!"}]} + ], + "generation_config": { + "max_output_tokens": 1000, + "temperature": 1 + } + } +) +``` + +--- + + +## Migration Strategies + +### **Gradual Migration** + +1. **Start with development** - Test Bifrost in dev environment +2. **Canary deployment** - Route 5% of traffic through Bifrost +3. **Feature-by-feature** - Migrate specific endpoints gradually +4. **Full migration** - Switch all traffic to Bifrost + +### **Blue-Green Migration** + +```python +import os +import random + +# Route traffic based on feature flag +def get_base_url(provider: str) -> str: + if os.getenv("USE_BIFROST", "false") == "true": + return f"http://bifrost:8080/{provider}" + else: + return f"https://api.{provider}.com" + +# Gradual rollout +def should_use_bifrost() -> bool: + rollout_percentage = int(os.getenv("BIFROST_ROLLOUT", "0")) + return random.randint(1, 100) <= rollout_percentage +``` + +### **Feature Flag Integration** + +```python +# Using feature flags for safe migration +import openai +from feature_flags import get_flag + +def create_client(): + if get_flag("use_bifrost_openai"): + base_url = "http://bifrost:8080/openai" + else: + base_url = "https://api.openai.com" + + return openai.OpenAI( + base_url=base_url, + api_key=os.getenv("OPENAI_API_KEY") + ) +``` + +--- + +## Next Steps + +- **[HTTP Transport Overview](../quickstart/gateway/setting-up)** - Main HTTP transport guide +- **[Endpoints](../apis/openapi.json)** - Complete API reference +- **[Configuration](../quickstart/gateway/provider-configuration)** - Provider setup and config diff --git a/docs/intercom.js b/docs/intercom.js new file mode 100644 index 000000000..2ec5006ee --- /dev/null +++ b/docs/intercom.js @@ -0,0 +1,8 @@ +window.intercomSettings = { + api_base: "https://api-iam.intercom.io", + app_id: "glx5mihe", +}; + + +// We pre-filled your app ID in the widget URL: 'https://widget.intercom.io/widget/glx5mihe' +(function () { var w = window; var ic = w.Intercom; if (typeof ic === "function") { ic('reattach_activator'); ic('update', w.intercomSettings); } else { var d = document; var i = function () { i.c(arguments); }; i.q = []; i.c = function (args) { i.q.push(args); }; w.Intercom = i; var l = function () { var s = d.createElement('script'); s.type = 'text/javascript'; s.async = true; s.src = 'https://widget.intercom.io/widget/glx5mihe'; var x = d.getElementsByTagName('script')[0]; x.parentNode.insertBefore(s, x); }; if (document.readyState === 'complete') { l(); } else if (w.attachEvent) { w.attachEvent('onload', l); } else { w.addEventListener('load', l, false); } } })(); diff --git a/docs/jsonLd.js b/docs/jsonLd.js new file mode 100644 index 000000000..7b8be576e --- /dev/null +++ b/docs/jsonLd.js @@ -0,0 +1,55 @@ +const jsonLd = { + "@context": "https://schema.org", + "@type": "WebPage", + url: "https://www.getmaxim.ai/bifrost/docs", + name: "Bifrost Documentation", + description: + "Comprehensive documentation for Maxim's end-to-end platform for AI simulation, evaluation, and observability. Learn how to build, evaluate, and monitor GenAI workflows at scale.", + publisher: { + "@type": "Organization", + name: "Bifrost", + url: "https://www.getmaxim.ai/bifrost", + logo: { + "@type": "ImageObject", + url: "https://bifrost.getmaxim.ai/logo-full.svg", + width: 300, + height: 60, + }, + sameAs: ["https://twitter.com/getmaximai", "https://www.linkedin.com/company/maxim-ai", "https://www.youtube.com/@getmaximai"], + }, + mainEntity: { + "@type": "TechArticle", + name: "Bifrost Documentation", + url: "https://www.getmaxim.ai/bifrost", + headline: "Bifrost Docs", + description: + "Bifrost is the fastest LLM gateway in the market, 90x faster than LiteLLM (P99 latency).", + inLanguage: "en", + }, +}; + +function injectJsonLd() { + const script = document.createElement("script"); + script.type = "application/ld+json"; + script.text = JSON.stringify(jsonLd); + + if (document.readyState === "loading") { + document.addEventListener("DOMContentLoaded", () => { + document.head.appendChild(script); + }); + } else { + document.head.appendChild(script); + } + + return () => { + if (script.parentNode) { + script.parentNode.removeChild(script); + } + }; +} + +// Call the function to inject JSON-LD +const cleanup = injectJsonLd(); + +// Cleanup when needed +// cleanup() \ No newline at end of file diff --git a/docs/media/aws-icon.png b/docs/media/aws-icon.png new file mode 100644 index 000000000..627547c13 Binary files /dev/null and b/docs/media/aws-icon.png differ diff --git a/docs/media/azure-icon.png b/docs/media/azure-icon.png new file mode 100644 index 000000000..7c750318d Binary files /dev/null and b/docs/media/azure-icon.png differ diff --git a/docs/media/bifrost-logo-dark.png b/docs/media/bifrost-logo-dark.png new file mode 100644 index 000000000..5049cb85f Binary files /dev/null and b/docs/media/bifrost-logo-dark.png differ diff --git a/docs/media/bifrost-logo.png b/docs/media/bifrost-logo.png new file mode 100644 index 000000000..b47319dc4 Binary files /dev/null and b/docs/media/bifrost-logo.png differ diff --git a/docs/media/cloudflare-icon.png b/docs/media/cloudflare-icon.png new file mode 100644 index 000000000..21f809aed Binary files /dev/null and b/docs/media/cloudflare-icon.png differ diff --git a/docs/media/clustering-diagram.png b/docs/media/clustering-diagram.png new file mode 100644 index 000000000..5b3a5d764 Binary files /dev/null and b/docs/media/clustering-diagram.png differ diff --git a/docs/media/cover.png b/docs/media/cover.png new file mode 100644 index 000000000..b19c328ca Binary files /dev/null and b/docs/media/cover.png differ diff --git a/docs/media/gcp-icon.png b/docs/media/gcp-icon.png new file mode 100644 index 000000000..2adedff32 Binary files /dev/null and b/docs/media/gcp-icon.png differ diff --git a/docs/media/gcp-icon.svg b/docs/media/gcp-icon.svg new file mode 100644 index 000000000..cb7a2aa70 --- /dev/null +++ b/docs/media/gcp-icon.svg @@ -0,0 +1,11 @@ + + + + + Error 404 (Not Found)!!1 + + +

404. That’s an error. +

The requested URL /devrel-devsite/prod/v2210deb8920cd4a55bd580441aa58e7853afc04b39a9d9ac4798e1aa28e803c49/cloud/images/cloud-logo.svg was not found on this server. That’s all we know. diff --git a/docs/media/getting-started.png b/docs/media/getting-started.png new file mode 100644 index 000000000..c7b2d1d8b Binary files /dev/null and b/docs/media/getting-started.png differ diff --git a/docs/media/maxim-logs.png b/docs/media/maxim-logs.png new file mode 100644 index 000000000..c738f8067 Binary files /dev/null and b/docs/media/maxim-logs.png differ diff --git a/docs/media/package-demo.mp4 b/docs/media/package-demo.mp4 new file mode 100644 index 000000000..a7651c07c Binary files /dev/null and b/docs/media/package-demo.mp4 differ diff --git a/docs/media/provider-configs.png b/docs/media/provider-configs.png new file mode 100644 index 000000000..8112b35ac Binary files /dev/null and b/docs/media/provider-configs.png differ diff --git a/docs/media/run-npx.mp4 b/docs/media/run-npx.mp4 new file mode 100644 index 000000000..3521738e6 Binary files /dev/null and b/docs/media/run-npx.mp4 differ diff --git a/docs/media/traffic-redistribution.png b/docs/media/traffic-redistribution.png new file mode 100644 index 000000000..fa8278690 Binary files /dev/null and b/docs/media/traffic-redistribution.png differ diff --git a/docs/media/ui-azure-config.png b/docs/media/ui-azure-config.png new file mode 100644 index 000000000..1d532845a Binary files /dev/null and b/docs/media/ui-azure-config.png differ diff --git a/docs/media/ui-bedrock-config.png b/docs/media/ui-bedrock-config.png new file mode 100644 index 000000000..bf7844b20 Binary files /dev/null and b/docs/media/ui-bedrock-config.png differ diff --git a/docs/media/ui-concurrency-buffer-size.png b/docs/media/ui-concurrency-buffer-size.png new file mode 100644 index 000000000..329206659 Binary files /dev/null and b/docs/media/ui-concurrency-buffer-size.png differ diff --git a/docs/media/ui-concurrency-timeout.png b/docs/media/ui-concurrency-timeout.png new file mode 100644 index 000000000..a415526f4 Binary files /dev/null and b/docs/media/ui-concurrency-timeout.png differ diff --git a/docs/media/ui-config-direct-keys.png b/docs/media/ui-config-direct-keys.png new file mode 100644 index 000000000..408ea11ee Binary files /dev/null and b/docs/media/ui-config-direct-keys.png differ diff --git a/docs/media/ui-config.png b/docs/media/ui-config.png new file mode 100644 index 000000000..6dc809f70 Binary files /dev/null and b/docs/media/ui-config.png differ diff --git a/docs/media/ui-create-customer.png b/docs/media/ui-create-customer.png new file mode 100644 index 000000000..0b01d038d Binary files /dev/null and b/docs/media/ui-create-customer.png differ diff --git a/docs/media/ui-create-teams.png b/docs/media/ui-create-teams.png new file mode 100644 index 000000000..108f13848 Binary files /dev/null and b/docs/media/ui-create-teams.png differ diff --git a/docs/media/ui-custom-provider.png b/docs/media/ui-custom-provider.png new file mode 100644 index 000000000..950a0702e Binary files /dev/null and b/docs/media/ui-custom-provider.png differ diff --git a/docs/media/ui-grafana-dashboard.png b/docs/media/ui-grafana-dashboard.png new file mode 100644 index 000000000..88b8a98de Binary files /dev/null and b/docs/media/ui-grafana-dashboard.png differ diff --git a/docs/media/ui-live-log-stream.gif b/docs/media/ui-live-log-stream.gif new file mode 100644 index 000000000..883da06d7 Binary files /dev/null and b/docs/media/ui-live-log-stream.gif differ diff --git a/docs/media/ui-log-filtering.gif b/docs/media/ui-log-filtering.gif new file mode 100644 index 000000000..1cb93a5d4 Binary files /dev/null and b/docs/media/ui-log-filtering.gif differ diff --git a/docs/media/ui-mcp-config.png b/docs/media/ui-mcp-config.png new file mode 100644 index 000000000..5176ddfd8 Binary files /dev/null and b/docs/media/ui-mcp-config.png differ diff --git a/docs/media/ui-multi-key-for-models.png b/docs/media/ui-multi-key-for-models.png new file mode 100644 index 000000000..b68ade0ff Binary files /dev/null and b/docs/media/ui-multi-key-for-models.png differ diff --git a/docs/media/ui-multimodal-tracing.png b/docs/media/ui-multimodal-tracing.png new file mode 100644 index 000000000..281a7c0df Binary files /dev/null and b/docs/media/ui-multimodal-tracing.png differ diff --git a/docs/media/ui-observability-config.png b/docs/media/ui-observability-config.png new file mode 100644 index 000000000..a929a1eff Binary files /dev/null and b/docs/media/ui-observability-config.png differ diff --git a/docs/media/ui-prometheus-labels.png b/docs/media/ui-prometheus-labels.png new file mode 100644 index 000000000..ad59bb54a Binary files /dev/null and b/docs/media/ui-prometheus-labels.png differ diff --git a/docs/media/ui-proxy-setup.png b/docs/media/ui-proxy-setup.png new file mode 100644 index 000000000..2fe913182 Binary files /dev/null and b/docs/media/ui-proxy-setup.png differ diff --git a/docs/media/ui-raw-response.png b/docs/media/ui-raw-response.png new file mode 100644 index 000000000..15f0a5537 Binary files /dev/null and b/docs/media/ui-raw-response.png differ diff --git a/docs/media/ui-request-tracing-overview.png b/docs/media/ui-request-tracing-overview.png new file mode 100644 index 000000000..8f88f6b1f Binary files /dev/null and b/docs/media/ui-request-tracing-overview.png differ diff --git a/docs/media/ui-semantic-cache-config.png b/docs/media/ui-semantic-cache-config.png new file mode 100644 index 000000000..ffe563881 Binary files /dev/null and b/docs/media/ui-semantic-cache-config.png differ diff --git a/docs/media/ui-tracing-config.png b/docs/media/ui-tracing-config.png new file mode 100644 index 000000000..0741f0386 Binary files /dev/null and b/docs/media/ui-tracing-config.png differ diff --git a/docs/media/ui-vertex-config.png b/docs/media/ui-vertex-config.png new file mode 100644 index 000000000..c091f9460 Binary files /dev/null and b/docs/media/ui-vertex-config.png differ diff --git a/docs/media/ui-virtual-key.png b/docs/media/ui-virtual-key.png new file mode 100644 index 000000000..4665244a9 Binary files /dev/null and b/docs/media/ui-virtual-key.png differ diff --git a/docs/media/ui-weighted-provider-keys.png b/docs/media/ui-weighted-provider-keys.png new file mode 100644 index 000000000..1e02e290f Binary files /dev/null and b/docs/media/ui-weighted-provider-keys.png differ diff --git a/docs/media/vercel-icon.png b/docs/media/vercel-icon.png new file mode 100644 index 000000000..7bdcd2a19 Binary files /dev/null and b/docs/media/vercel-icon.png differ diff --git a/docs/quickstart/README.mdx b/docs/quickstart/README.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/quickstart/gateway/debugging.mdx b/docs/quickstart/gateway/debugging.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/quickstart/gateway/integrations.mdx b/docs/quickstart/gateway/integrations.mdx new file mode 100644 index 000000000..a72273ce5 --- /dev/null +++ b/docs/quickstart/gateway/integrations.mdx @@ -0,0 +1,69 @@ +--- +title: "Integrations" +description: "Use Bifrost as a drop-in replacement for existing AI provider SDKs with zero code changes. Just change the base URL and unlock advanced features." +icon: "link" +--- + +## What are Integrations? + +Integrations are protocol adapters that make Bifrost **100% compatible** with existing AI provider SDKs. They translate between provider-specific API formats (OpenAI, Anthropic, Google GenAI) and Bifrost's unified API, enabling you to: + +- **Drop-in replacement** - Change only the base URL in your existing code +- **Zero migration effort** - Keep your current SDK and request/response handling +- **Instant feature access** - Get governance, caching, fallbacks, and monitoring without code changes + +## Quick Example + +### Before (Direct Provider) +```python +import openai + +client = openai.OpenAI( + api_key="your-openai-key" +) +``` + +### After (Bifrost Integration) +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/openai", # Point to Bifrost + api_key="dummy-key" # Keys handled by Bifrost +) +``` + +**That's it!** Your application now has automatic fallbacks, governance, monitoring, and all Bifrost features. + +## Available Integrations + +Bifrost provides complete compatibility with these popular AI SDKs: + +- **[OpenAI SDK](../../integrations/openai-sdk)** +- **[Anthropic SDK](../../integrations/anthropic-sdk)** +- **[Google GenAI SDK](../../integrations/genai-sdk)** +- **[LiteLLM](../../integrations/litellm-sdk)** +- **[LangChain](../../integrations/langchain-sdk)** + +## Learn More + +For detailed setup guides, compatibility information, and advanced usage: + +**➜ [Complete Integration Documentation](../../integrations/what-is-an-integration)** + +## Next Steps + +Now that you understand integrations, explore these related topics: + +### Essential Topics + +- **[Provider Configuration](./provider-configuration)** - Set up multiple AI providers for redundancy +- **[Tool Calling](./tool-calling)** - Enable AI models to use external functions +- **[Streaming Responses](./streaming)** - Real-time response generation +- **[Multimodal AI](./multimodal)** - Process images, audio, and multimedia content + +### Advanced Topics + +- **[Core Features](../../features/)** - Governance, caching, and observability +- **[Architecture](../../architecture/)** - How Bifrost works internally +- **[Deployment](../../deployment/)** - Production setup and scaling diff --git a/docs/quickstart/gateway/multimodal.mdx b/docs/quickstart/gateway/multimodal.mdx new file mode 100644 index 000000000..6a260b485 --- /dev/null +++ b/docs/quickstart/gateway/multimodal.mdx @@ -0,0 +1,314 @@ +--- +title: "Multimodal Support" +description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, speech synthesis, and audio transcription across various providers." +icon: "images" +--- + +## Vision: Analyzing Images with AI + +Send images to vision-capable models for analysis, description, and understanding. This example shows how to analyze an image from a URL using GPT-4o with high detail processing for better accuracy. + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What do you see in this image? Please describe it in detail." + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "detail": "high" + } + } + ] + } + ] +}' +``` + +**Response includes detailed image analysis:** +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "I can see a beautiful wooden boardwalk extending through a natural landscape..." + } + }] +} +``` + +## Audio Understanding: Analyzing Audio with AI + +If your chat application supports text input, you can add audio input and outputβ€”just include audio in the modalities array and use an audio model, like gpt-4o-audio-preview. + +### Audio Input to Model + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-audio-preview", + "modalities": ["text"], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Please analyze this audio recording and summarize what was discussed." + }, + { + "type": "input_audio", + "input_audio": { + "data": "", + "format": "wav" + } + } + ] + } + ] +}' +``` + +### Audio Output from Model + +```bash +{ + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "The audio recording captured a brief segment where a speaker simply said \"Affirmative\" in response. There wasn't any detailed discussion or context provided beyond that one-word affirmation. If you have more audio or specific questions, feel free to share!" + } + } + ] +} +``` + +## Text-to-Speech: Converting Text to Audio + +Convert text into natural-sounding speech using AI voice models. This example demonstrates generating an MP3 audio file from text using the "alloy" voice. The result is returned as binary audio data. + +```bash +curl --location 'http://localhost:8080/v1/audio/speech' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/tts-1", + "input": "Hello! This is a sample text that will be converted to speech using Bifrost speech synthesis capabilities. The weather today is wonderful, and I hope you are having a great day!", + "voice": "alloy", + "response_format": "mp3" +}' \ +--output "output.mp3" +``` + +**Save audio to file:** +```bash +# The --output flag saves the binary audio data directly to a file +# File size will vary based on input text length +``` + +## Speech-to-Text: Transcribing Audio Files + +Convert audio files into text using AI transcription models. This example shows how to transcribe an MP3 file using OpenAI's Whisper model, with an optional context prompt to improve accuracy. + +```bash +curl --location 'http://localhost:8080/v1/audio/transcriptions' \ +--form 'file=@"output.mp3"' \ +--form 'model="openai/whisper-1"' \ +--form 'prompt="This is a sample audio transcription from Bifrost speech synthesis."' +``` + +**Response format:** +```json +{ + "text": "Hello! This is a sample text that will be converted to speech using Bifrost speech synthesis capabilities. The weather today is wonderful, and I hope you are having a great day!" +} +``` + +## Advanced Vision Examples + +### Multiple Images + +Send multiple images in a single request for comparison or analysis. This is useful for comparing products, analyzing changes over time, or understanding relationships between different visual elements. + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Compare these two images. What are the differences?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image1.jpg" + } + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image2.jpg" + } + } + ] + } + ] +}' +``` + +### Base64 Images + +Process local images by encoding them as base64 data URLs. This approach is ideal when you need to analyze images stored locally on your system without uploading them to external URLs first. + +```bash +# First, encode your local image to base64 +base64_image=$(base64 -i local_image.jpg) +data_url="data:image/jpeg;base64,$base64_image" + +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Analyze this image and describe what you see." + }, + { + "type": "image_url", + "image_url": { + "url": "'$data_url'", + "detail": "high" + } + } + ] + } + ] +}' +``` + +## Audio Configuration Options + +### Voice Selection for Speech Synthesis + +OpenAI provides six distinct voice options, each with different characteristics: + +- `alloy` - Balanced, natural voice +- `echo` - Deep, resonant voice +- `fable` - Expressive, storytelling voice +- `onyx` - Strong, confident voice +- `nova` - Bright, energetic voice +- `shimmer` - Gentle, soothing voice + +```bash +# Example with different voice +curl --location 'http://localhost:8080/v1/audio/speech' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/tts-1", + "input": "This is the nova voice speaking.", + "voice": "nova", + "response_format": "mp3" +}' \ +--output "sample_nova.mp3" +``` + +### Audio Formats + +Generate audio in different formats depending on your use case. MP3 for general use, Opus for web streaming, AAC for mobile apps, and FLAC for high-quality audio applications. + +```bash +# MP3 format (default) +"response_format": "mp3" + +# Opus format for web streaming +"response_format": "opus" + +# AAC format for mobile apps +"response_format": "aac" + +# FLAC format for high-quality audio +"response_format": "flac" +``` + +## Transcription Options + +### Language Specification + +Improve transcription accuracy by specifying the source language. This is particularly helpful for non-English audio or when the audio contains technical terms or specific domain vocabulary. + +```bash +curl --location 'http://localhost:8080/v1/audio/transcriptions' \ +--form 'file=@"spanish_audio.mp3"' \ +--form 'model="openai/whisper-1"' \ +--form 'language="es"' \ +--form 'prompt="This is a Spanish audio recording about technology."' +``` + +### Response Formats + +Choose between simple text output or detailed JSON responses with timestamps. The verbose JSON format provides word-level and segment-level timing information, useful for creating subtitles or analyzing speech patterns. + +```bash +# Text only response +curl --location 'http://localhost:8080/v1/audio/transcriptions' \ +--form 'file=@"audio.mp3"' \ +--form 'model="openai/whisper-1"' \ +--form 'response_format="text"' + +# JSON with timestamps +curl --location 'http://localhost:8080/v1/audio/transcriptions' \ +--form 'file=@"audio.mp3"' \ +--form 'model="openai/whisper-1"' \ +--form 'response_format="verbose_json"' \ +--form 'timestamp_granularities[]=word' \ +--form 'timestamp_granularities[]=segment' +``` + +## Provider Support + +Different providers support different multimodal capabilities: + +| Provider | Vision | Text-to-Speech | Speech-to-Text | +|----------|--------|----------------|----------------| +| OpenAI | βœ… GPT-4V, GPT-4o | βœ… TTS-1, TTS-1-HD | βœ… Whisper | +| Anthropic | βœ… Claude 3 Sonnet/Opus | ❌ | ❌ | +| Google Vertex | βœ… Gemini Pro Vision | βœ… | βœ… | +| Azure OpenAI | βœ… GPT-4V | βœ… | βœ… Whisper | + +## Next Steps + +Now that you understand multimodal capabilities, explore these related topics: + +### Essential Topics + +- **[Streaming Responses](./streaming)** - Real-time multimodal processing +- **[Tool Calling](./tool-calling)** - Combine with external tools +- **[Provider Configuration](./provider-configuration)** - Multiple providers for different capabilities +- **[Integrations](./integrations)** - Drop-in compatibility with existing SDKs + +### Advanced Topics + +- **[Core Features](../../features/)** - Advanced Bifrost capabilities +- **[Architecture](../../architecture/)** - How Bifrost works internally +- **[Deployment](../../deployment/)** - Production setup and scaling diff --git a/docs/quickstart/gateway/provider-configuration.mdx b/docs/quickstart/gateway/provider-configuration.mdx new file mode 100644 index 000000000..a78c5a16b --- /dev/null +++ b/docs/quickstart/gateway/provider-configuration.mdx @@ -0,0 +1,927 @@ +--- +title: "Provider Configuration" +description: "Configure multiple AI providers for custom concurrency, queue sizes, proxy settings, and more." +icon: "sliders" +--- + +## Multi-Provider Setup + +Configure multiple providers to seamlessly switch between them. This example shows how to configure OpenAI, Anthropic, and Mistral providers. + + + + + +![Provider Configuration Interface](../../media/provider-configs.png) + +1. Go to **http://localhost:8080** +2. Navigate to **"Providers"** in the sidebar +3. Click **"Add Provider"** +4. Select provider and configure keys +5. Save configuration + + + + + +```bash +# Add OpenAI provider +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ] +}' + +# Add Anthropic provider +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "anthropic", + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ] + }, + "anthropic": { + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ] + } + } +} +``` + + + + + +## Making Requests + +Once providers are configured, you can make requests to any specific provider. This example shows how to send a request directly to OpenAI's GPT-4o Mini model. Bifrost handles the provider-specific API formatting automatically. + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Hello!"} + ] +}' +``` + +## Environment Variables + +Set up your API keys for the providers you want to use. Bifrost supports both direct key values and environment variable references with the `env.` prefix: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +export ANTHROPIC_API_KEY="your-anthropic-api-key" +export MISTRAL_API_KEY="your-mistral-api-key" +export GROQ_API_KEY="your-groq-api-key" +export COHERE_API_KEY="your-cohere-api-key" +``` + +**Environment Variable Handling:** +- Use `"value": "env.VARIABLE_NAME"` to reference environment variables +- Use `"value": "sk-proj-xxxxxxxxx"` to pass keys directly +- All sensitive data is automatically redacted in GET requests and UI responses for security + +## Advanced Configuration + +### Weighted Load Balancing + +Distribute requests across multiple API keys or providers based on custom weights. This example shows how to split traffic 70/30 between two OpenAI keys, useful for managing rate limits or costs across different accounts. + + + + + +![Weighted Load Balancing Interface](../../media/ui-weighted-provider-keys.png) + +1. Navigate to **"Providers"** β†’ **"OpenAI"** +2. Click **"Add Key"** to add multiple keys +3. Set weight values (0.7 and 0.3) +4. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY_1", + "models": [], + "weight": 0.7 + }, + { + "value": "env.OPENAI_API_KEY_2", + "models": [], + "weight": 0.3 + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY_1", + "models": [], + "weight": 0.7 + }, + { + "value": "env.OPENAI_API_KEY_2", + "models": [], + "weight": 0.3 + } + ] + } + } +} +``` + + + + + +### Model-Specific Keys + +Use different API keys for specific models, allowing you to manage access controls and billing separately. This example uses a premium key for advanced reasoning models (o1-preview, o1-mini) and a standard key for regular GPT models. + + + + + +![Model-Specific Keys Interface](../../media/ui-multi-key-for-models.png) + +1. Navigate to **"Providers"** β†’ **"OpenAI"** +2. Add first key with models: `["gpt-4o", "gpt-4o-mini"]` +3. Add premium key with models: `["o1-preview", "o1-mini"]` +4. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0 + }, + { + "value": "env.OPENAI_API_KEY_PREMIUM", + "models": ["o1-preview", "o1-mini"], + "weight": 1.0 + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0 + }, + { + "value": "env.OPENAI_API_KEY_PREMIUM", + "models": ["o1-preview", "o1-mini"], + "weight": 1.0 + } + ] + } + } +} +``` + + + + + +### Custom Network Settings + +Customize the network configuration for each provider, including custom base URLs, extra headers, and timeout settings. This example shows how to use a local OpenAI-compatible server with custom headers for user identification. + + + + + +![Network Configuration Interface](../../media/ui-proxy-setup.png) + +1. Navigate to **"Providers"** β†’ **"OpenAI"** β†’ **"Advanced"** +2. Set **Base URL**: `http://localhost:8000/v1` +3. Set **Timeout**: `30` seconds +4. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://localhost:8000/v1", + "extra_headers": { + "x-user-id": "123" + }, + "default_request_timeout_in_seconds": 30 + } +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://localhost:8000/v1", + "extra_headers": { + "x-user-id": "123" + }, + "default_request_timeout_in_seconds": 30 + } + } + } +} +``` + + + + + +### Managing Retries + +Configure retry behavior for handling temporary failures and rate limits. This example sets up exponential backoff with up to 5 retries, starting with 1ms delay and capping at 10 seconds - ideal for handling transient network issues. + + + + + +![Retry Configuration Interface](../../media/ui-concurrency-timeout.png) + +1. Navigate to **"Providers"** β†’ **"OpenAI"** β†’ **"Advanced"** +2. Set **Max Retries**: `5` +3. Set **Initial Backoff**: `1` ms +4. Set **Max Backoff**: `10000` ms +5. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "network_config": { + "max_retries": 5, + "retry_backoff_initial_ms": 1, + "retry_backoff_max_ms": 10000 + } +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "network_config": { + "max_retries": 5, + "retry_backoff_initial_ms": 1, + "retry_backoff_max_ms": 10000 + } + } + } +} +``` + + + + + +### Custom Concurrency and Buffer Size + +Fine-tune performance by adjusting worker concurrency and queue sizes per provider (defaults are 1000 workers and 5000 queue size). This example gives OpenAI higher limits (100 workers, 500 queue) for high throughput, while Anthropic gets conservative limits to respect their rate limits. + + + + + +![Concurrency Configuration Interface](../../media/ui-concurrency-buffer-size.png) + +1. Navigate to **"Providers"** β†’ **Provider** β†’ **"Performance"** +2. Set **Concurrency**: Worker count (100 for OpenAI, 25 for Anthropic) +3. Set **Buffer Size**: Queue size (500 for OpenAI, 100 for Anthropic) +4. Save configuration + + + + + +```bash +# OpenAI with high throughput settings +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "concurrency_and_buffer_size": { + "concurrency": 100, + "buffer_size": 500 + } +}' + +# Anthropic with conservative settings +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "anthropic", + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "concurrency_and_buffer_size": { + "concurrency": 25, + "buffer_size": 100 + } +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "concurrency_and_buffer_size": { + "concurrency": 100, + "buffer_size": 500 + } + }, + "anthropic": { + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "concurrency_and_buffer_size": { + "concurrency": 25, + "buffer_size": 100 + } + } + } +} +``` + + + + + +### Setting Up a Proxy + +Route requests through proxies for compliance, security, or geographic requirements. This example shows both HTTP proxy for OpenAI and authenticated SOCKS5 proxy for Anthropic, useful for corporate environments or regional access. + + + + + +![Proxy Configuration Interface](../../media/ui-proxy-setup.png) + +1. Navigate to **"Providers"** β†’ **Provider** β†’ **"Proxy"** +2. Select **Proxy Type**: HTTP or SOCKS5 +3. Set **Proxy URL**: `http://localhost:8000` +4. Add credentials if needed (username/password) +5. Save configuration + + + + + +```bash +# HTTP proxy for OpenAI +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "proxy_config": { + "type": "http", + "url": "http://localhost:8000" + } +}' + +# SOCKS5 proxy with authentication for Anthropic +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "anthropic", + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "proxy_config": { + "type": "socks5", + "url": "http://localhost:8000", + "username": "user", + "password": "password" + } +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "proxy_config": { + "type": "http", + "url": "http://localhost:8000" + } + }, + "anthropic": { + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "proxy_config": { + "type": "socks5", + "url": "http://localhost:8000", + "username": "user", + "password": "password" + } + } + } +} +``` + + + + + +### Send Back Raw Response + +Include the original provider response alongside Bifrost's standardized response format. Useful for debugging and accessing provider-specific metadata. + + + + + +![Raw Response Configuration Interface](../../media/ui-raw-response.png) + +1. Navigate to **"Providers"** β†’ **Provider** β†’ **"Advanced"** +2. Toggle **"Include Raw Response"** to enabled +3. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "send_back_raw_response": true +}' +``` + + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [], + "weight": 1.0 + } + ], + "send_back_raw_response": true + } + } +} +``` + + + + + +When enabled, the raw provider response appears in `extra_fields.raw_response`: + +```json +{ + "choices": [...], + "usage": {...}, + "extra_fields": { + "provider": "openai", + "raw_response": { + // Original OpenAI response here + } + } +} +``` + +## Provider-Specific Authentication + +Enterprise cloud providers require additional configuration beyond API keys. Configure Azure OpenAI, AWS Bedrock, and Google Vertex with platform-specific authentication details. + +### Azure OpenAI + +Azure OpenAI requires endpoint URLs, deployment mappings, and API version configuration: + + + + + +![Azure OpenAI Configuration Interface](../../media/ui-azure-config.png) + +1. Navigate to **"Providers"** β†’ **"Azure OpenAI"** +2. Set **API Key**: Your Azure API key +3. Set **Endpoint**: Your Azure endpoint URL +4. Configure **Deployments**: Map model names to deployment names +5. Set **API Version**: e.g., `2024-08-01-preview` +6. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "azure", + "keys": [ + { + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, + "api_version": "2024-08-01-preview" + } + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, + "api_version": "2024-08-01-preview" + } + } + ] + } + } +} +``` + + + + + +### AWS Bedrock + +AWS Bedrock supports both explicit credentials and IAM role authentication: + + + + + +![AWS Bedrock Configuration Interface](../../media/ui-bedrock-config.png) + +1. Navigate to **"Providers"** β†’ **"AWS Bedrock"** +2. Set **API Key**: AWS API Key (or leave empty if using IAM role authentication) +3. Set **Access Key**: AWS Access Key ID (or leave empty to use IAM in environment) +4. Set **Secret Key**: AWS Secret Access Key (or leave empty to use IAM in environment) +5. Set **Region**: e.g., `us-east-1` +6. Configure **Deployments**: Map model names to inference profiles +7. Set **ARN**: Required for deployments mapping +8. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "bedrock", + "keys": [ + { + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "weight": 1.0, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1", + "deployments": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, + "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" + } + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "weight": 1.0, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1", + "deployments": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, + "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" + } + } + ] + } + } +} +``` + + + + + +**Notes:** +- If using API Key authentication, set `value` field to the API key, else leave it empty for IAM role authentication. +- In IAM role authentication, if both `access_key` and `secret_key` are empty, Bifrost uses IAM role authentication from the environment. +- `arn` is required for URL formation - `deployments` mapping is ignored without it. +- When using `arn` + `deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. + +### Google Vertex + +Google Vertex requires project configuration and authentication credentials: + + + + + +![Google Vertex Configuration Interface](../../media/ui-vertex-config.png) + +1. Navigate to **"Providers"** β†’ **"Google Vertex"** +2. Set **API Key**: Your Vertex API key +3. Set **Project ID**: Your Google Cloud project ID +4. Set **Region**: e.g., `us-central1` +5. Set **Auth Credentials**: Service account credentials JSON +6. Save configuration + + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "vertex", + "keys": [ + { + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-pro-vision"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ] +}' +``` + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-pro-vision"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ] + } + } +} +``` + + + + + +## Next Steps + +Now that you understand provider configuration, explore these related topics: + +### Essential Topics + +- **[Streaming Responses](./streaming)** - Real-time response generation +- **[Tool Calling](./tool-calling)** - Enable AI to use external functions +- **[Multimodal AI](./multimodal)** - Process images, audio, and text +- **[Integrations](./integrations)** - Drop-in compatibility with existing SDKs + +### Advanced Topics + +- **[Core Features](../../features/)** - Advanced Bifrost capabilities +- **[Architecture](../../architecture/)** - How Bifrost works internally +- **[Deployment](../../deployment/)** - Production setup and scaling diff --git a/docs/quickstart/gateway/setting-up.mdx b/docs/quickstart/gateway/setting-up.mdx new file mode 100644 index 000000000..d7afaf5f4 --- /dev/null +++ b/docs/quickstart/gateway/setting-up.mdx @@ -0,0 +1,203 @@ +--- +title: "Setting Up" +description: "Get Bifrost running as an HTTP API gateway in 30 seconds with zero configuration. Perfect for any programming language." +icon: "play" +--- + +![Bifrost Gateway Installation](../../media/getting-started.png) + +## 30-Second Setup + +Get Bifrost running as a blazing-fast HTTP API gateway with **zero configuration**. Connect to any AI provider (OpenAI, Anthropic, Bedrock, and more) through a unified API that follows **OpenAI request/response format**. + +### 1. Choose Your Setup Method + +Both options work perfectly - choose what fits your workflow: + +#### NPX Binary + + + +```bash +# Install and run locally +npx -y @maximhq/bifrost +``` + +#### Docker + +```bash +# Pull and run Bifrost HTTP API +docker pull maximhq/bifrost +docker run -p 8080:8080 maximhq/bifrost +``` + +**For Data Persistence** + +```bash +# For configuration persistence across restarts +docker run -p 8080:8080 -v $(pwd)/data:/app/data maximhq/bifrost +``` +### 2. Configuration Flags + +| Flag | Default | NPX | Docker | Description | +|------|---------|-----|--------|-------------| +| port | 8080 | `-port 8080` | `-e APP_PORT=8080 -p 8080:8080` | HTTP server port | +| host | localhost | `-host 0.0.0.0` | `-e APP_HOST=0.0.0.0` | Host to bind server to | +| log-level | info | `-log-level info` | `-e LOG_LEVEL=info` | Log level (debug, info, warn, error) | +| log-style | json | `-log-style json` | `-e LOG_STYLE=json` | Log style (pretty, json) | + + +**Understanding App Directory** + +The `-app-dir` flag determines where Bifrost stores all its data: + +```bash +# Specify custom directory +npx -y @maximhq/bifrost -app-dir ./my-bifrost-data + +# If not specified, creates in your OS config directory: +# β€’ Linux/macOS: ~/.config/bifrost +# β€’ Windows: %APPDATA%\bifrost +``` + +**What's stored in app-dir:** +- `config.json` - Configuration file (optional) +- `config.db` - SQLite database for UI configuration +- `logs.db` - Request logs database + +**Note:** When using Bifrost via Docker, the volume you mount will be used as the app-dir. + +### 3. Open the Web Interface + +Navigate to **http://localhost:8080** in your browser: + +```bash +# macOS +open http://localhost:8080 + +# Linux +xdg-open http://localhost:8080 + +# Windows +start http://localhost:8080 +``` + +πŸ–₯️ **The Web UI provides:** +- **Visual provider setup** - Add API keys with clicks, not code +- **Real-time configuration** - Changes apply immediately +- **Live monitoring** - Request logs, metrics, and analytics +- **Governance management** - Virtual keys, usage budgets, and more + +### 4. Test Your First API Call + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello, Bifrost!"}] + }' +``` + +**πŸŽ‰ That's it!** Bifrost is running and ready to route AI requests. + +### What Just Happened? + +1. **Zero Configuration Start**: Bifrost launched without any config files - everything can be configured through the Web UI or API +2. **OpenAI-Compatible API**: All Bifrost APIs follow OpenAI request/response format for seamless integration +3. **Unified API Endpoint**: `/v1/chat/completions` works with any provider (OpenAI, Anthropic, Bedrock, etc.) +4. **Provider Resolution**: `openai/gpt-4o-mini` tells Bifrost to use OpenAI's GPT-4o Mini model +5. **Automatic Routing**: Bifrost handles authentication, rate limiting, and request routing automatically + +--- + +## Two Configuration Modes + +Bifrost supports **two configuration approaches** - you cannot use both simultaneously: + +### Mode 1: Web UI Configuration + +![Configuration via UI](../../media/ui-config.png) + +**When the UI is available:** +- No `config.json` file exists (Bifrost auto-creates SQLite database) +- `config.json` exists with `config_store` configured + +### Mode 2: File-based Configuration + +**When to use:** Advanced setups, GitOps workflows, or when UI is not needed + +Create `config.json` in your app directory: + +```json +{ + "client": { + "drop_excess_requests": false + }, + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini", "gpt-4o"], + "weight": 1.0 + } + ] + } + }, + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./config.db" + } + } +} +``` + +**Without `config_store` in `config.json`:** +- **UI is disabled** - no real-time configuration possible +- **Read-only mode** - `config.json` is never modified +- **Memory-only** - all configurations loaded into memory at startup +- **Restart required** - changes to `config.json` only apply after restart + +**With `config_store` in `config.json`:** +- **UI is enabled** - full real-time configuration via web interface +- **Database check** - Bifrost checks if config store database exists and has data + - **Empty DB**: Bootstraps database with `config.json` settings, then uses DB exclusively + - **Existing DB**: Uses database directly, **ignores** `config.json` configurations +- **Persistent storage** - all changes saved to database immediately + +**Important for Advanced Users:** +If you want database persistence but prefer not to use the UI, note that modifying `config.json` after initial bootstrap has no effect when `config_store` is enabled. Use the public HTTP APIs to make configuration changes instead. + +**The Three Stores Explained:** +- **Config Store**: Stores provider configs, API keys, MCP settings - Required for UI functionality +- **Logs Store**: Stores request logs shown in UI - Optional, can be disabled +- **Vector Store**: Used for semantic caching - Optional, can be disabled + +--- + +## Next Steps + +Now that you have Bifrost running, explore these focused guides: + +### Essential Topics + +- **[Provider Configuration](./provider-configuration)** - Multiple providers, automatic failovers & load balancing +- **[Integrations](../../integrations/what-is-an-integration)** - Drop-in replacements for OpenAI, Anthropic, and GenAI SDKs +- **[Multimodal Support](./multimodal)** - Support for text, images, audio, and streaming, all behind a common interface. + +### Advanced Topics + +- **[Tracing](../../features/tracing)** - Logging requests for monitoring and debugging +- **[MCP Tools](../../features/mcp)** - Enable AI models to use external tools (filesystem, web search, databases) +- **[Governance](../../features/governance)** - Usage tracking, rate limiting, and cost control +- **[Deployment](../../deployment/docker-setup)** - Production setup and scaling + +--- + +**Happy building with Bifrost!** πŸš€ diff --git a/docs/quickstart/gateway/streaming.mdx b/docs/quickstart/gateway/streaming.mdx new file mode 100644 index 000000000..2dd2dd362 --- /dev/null +++ b/docs/quickstart/gateway/streaming.mdx @@ -0,0 +1,119 @@ +--- +title: "Streaming Responses" +description: "Receive AI responses in real-time via Server-Sent Events. Perfect for chat applications, audio processing, and real-time transcription where you want immediate results." +icon: "water" +--- + +## Streaming Chat Responses + +Receive AI responses in real-time as they're generated. Perfect for chat applications where you want to show responses as they're being typed, improving user experience. + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Tell me a story about a robot learning to paint"} + ], + "stream": true +}' +``` + +**Response Format (Server-Sent Events):** +``` +data: {"choices":[{"delta":{"content":"Once"}}],"model":"gpt-4o-mini"} + +data: {"choices":[{"delta":{"content":" upon"}}],"model":"gpt-4o-mini"} + +data: {"choices":[{"delta":{"content":" a"}}],"model":"gpt-4o-mini"} + +data: [DONE] +``` + +Each chunk contains partial content that you can append to build the complete response in real-time. + +> **Note:** Streaming requests also follow the default timeout setting defined in provider configuration, which defaults to **30 seconds**. + + +Bifrost standardizes all stream responses to send usage and finish reason only in the last chunk, and content in the previous chunks. + + +## Text-to-Speech Streaming: Real-time Audio Generation + +Stream audio generation in real-time as text is converted to speech. Ideal for long texts or when you need immediate audio playback. + +```bash +curl --location 'http://localhost:8080/v1/audio/speech' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-mini-tts", + "input": "Hello this is a sample test, respond with hello for my Bifrost", + "voice": "alloy", + "stream_format": "sse" +}' +``` + +**Response:** Audio chunks are delivered via Server-Sent Events. Each chunk contains base64-encoded audio data that you can decode and play or save progressively. + +``` +data: {"audio":"UklGRigAAABXQVZFZm10IBAAAAABAAEA..."} + +data: {"audio":"AKlFQVZFZm10IBAAAAABAAEAq..."} + +data: [DONE] +``` + +**To save the stream:** Add `> audio_stream.txt` to redirect output to a file. + +## Speech-to-Text Streaming: Real-time Audio Transcription + +Stream audio transcription results as they're processed. Get immediate text output for real-time applications or long audio files. + +```bash +curl --location 'http://localhost:8080/v1/audio/transcriptions' \ +--form 'file=@"/path/to/your/audio.mp3"' \ +--form 'model="openai/gpt-4o-transcribe"' \ +--form 'stream="true"' \ +--form 'response_format="json"' +``` + +**Response Format:** +``` +data: {"text":"Hello"} + +data: {"text":" this"} + +data: {"text":" is"} + +data: {"text":" a sample"} + +data: [DONE] +``` + +**Additional options:** Add `--form 'language="en"'` or `--form 'prompt="context hint"'` for better accuracy. + +## Audio Format Support + +**Speech Synthesis:** Supports `"response_format": "mp3"` (default) and `"response_format": "wav"` + +**Transcription Input:** Accepts MP3, WAV, M4A, and other common audio formats + +> **Note:** Streaming capabilities vary by provider and model. Check each provider's documentation for specific streaming support and limitations. + +## Next Steps + +Now that you understand streaming responses, explore these related topics: + +### Essential Topics + +- **[Tool Calling](./tool-calling)** - Enable AI models to use external tools and functions +- **[Multimodal AI](./multimodal)** - Process images, audio, and multimedia content +- **[Provider Configuration](./provider-configuration)** - Multiple providers for redundancy +- **[Integrations](./integrations)** - Drop-in compatibility with existing SDKs + +### Advanced Topics + +- **[Core Features](../../features/)** - Advanced Bifrost capabilities +- **[Architecture](../../architecture/)** - How Bifrost works internally +- **[Deployment](../../deployment/)** - Production setup and scaling diff --git a/docs/quickstart/gateway/tool-calling.mdx b/docs/quickstart/gateway/tool-calling.mdx new file mode 100644 index 000000000..9117559a2 --- /dev/null +++ b/docs/quickstart/gateway/tool-calling.mdx @@ -0,0 +1,165 @@ +--- +title: "Tool Calling" +description: "Enable AI models to use external functions and services by defining tool schemas or connecting to Model Context Protocol (MCP) servers. This allows AI to interact with databases, APIs, file systems, and more." +icon: "wrench" +--- + +## Function Calling with Custom Tools + +Enable AI models to use external functions by defining tool schemas using OpenAI format. Models can then call these functions automatically based on user requests. + +```bash +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-mini", + "messages": [ + {"role": "user", "content": "What is 15 + 27? Use the calculator tool."} + ], + "tools": [ + { + "type": "function", + "function": { + "name": "calculator", + "description": "A calculator tool for basic arithmetic operations", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "description": "The operation to perform", + "enum": ["add", "subtract", "multiply", "divide"] + }, + "a": { + "type": "number", + "description": "The first number" + }, + "b": { + "type": "number", + "description": "The second number" + } + }, + "required": ["operation", "a", "b"] + } + } + } + ], + "tool_choice": "auto" +}' +``` + +**Response includes tool calls:** +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call_abc123", + "type": "function", + "function": { + "name": "calculator", + "arguments": "{\"operation\":\"add\",\"a\":15,\"b\":27}" + } + }] + } + }] +} +``` + +## Connecting to MCP Servers + +Connect to Model Context Protocol (MCP) servers to give AI models access to external tools and services without manually defining each function. + + + +![MCP Configuration Interface](../../media/ui-mcp-config.png) + +1. Go to **http://localhost:8080** +2. Navigate to **"MCP Clients"** in the sidebar +3. Click **"Add MCP Client"** +4. Enter server details and save + + + +```bash +curl --location 'http://localhost:8080/api/mcp/client' \ +--header 'Content-Type: application/json' \ +--data '{ + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": ["npx", "@modelcontextprotocol/server-filesystem", "/tmp"], + "args": [] + } +}' +``` + +**List configured MCP clients:** +```bash +curl --location 'http://localhost:8080/api/mcp/clients' +``` + + + +```json +{ + "mcp": { + "client_configs": [ + { + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": ["npx", "@modelcontextprotocol/server-filesystem", "/tmp"], + "args": [] + } + }, + { + "name": "youtube-search", + "connection_type": "http", + "connection_string": "http://your-youtube-mcp-url" + } + ] + } +} +``` + + + + +Read more about MCP connections and advanced end to end tool execution in the [MCP Features](../../features/mcp) section. + +## Tool Choice Options + +Control how the AI uses tools: + +```bash +# Force use of specific tool +"tool_choice": { + "type": "function", + "function": {"name": "calculator"} +} + +# Let AI decide automatically (default) +"tool_choice": "auto" + +# Disable tool usage +"tool_choice": "none" +``` + +## Next Steps + +Now that you understand tool calling, explore these related topics: + +### Essential Topics + +- **[Multimodal AI](./multimodal)** - Process images, audio, and multimedia content +- **[Streaming Responses](./streaming)** - Real-time response generation with tool calls +- **[Provider Configuration](./provider-configuration)** - Multiple providers for redundancy +- **[Integrations](./integrations)** - Drop-in compatibility with existing SDKs + +### Advanced Topics + +- **[MCP Features](../../features/mcp)** - Advanced MCP server management and configuration +- **[Core Features](../../features/)** - Advanced Bifrost capabilities +- **[Architecture](../../architecture/)** - How Bifrost works internally diff --git a/docs/quickstart/go-sdk/logger.mdx b/docs/quickstart/go-sdk/logger.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/docs/quickstart/go-sdk/multimodal.mdx b/docs/quickstart/go-sdk/multimodal.mdx new file mode 100644 index 000000000..3c3d8cf6e --- /dev/null +++ b/docs/quickstart/go-sdk/multimodal.mdx @@ -0,0 +1,357 @@ +--- +title: "Multimodal Support" +description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, speech synthesis, and audio transcription across various providers." +icon: "images" +--- + +## Vision: Analyzing Images with AI + +Send images to vision-capable models for analysis, description, and understanding. This example shows how to analyze an image from a URL using GPT-4o with high detail processing for better accuracy. + +```go +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o", // Using vision-capable model + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("What do you see in this image? Please describe it in detail."), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + Detail: bifrost.Ptr("high"), // Optional: can be "low", "high", or "auto" + }, + }, + }, + }, + }, + }, + }, +}) + +if err != nil { + panic(err) +} + +fmt.Println("Response:", *response.Choices[0].Message.Content.ContentStr) +``` + +## Audio Understanding: Analyzing Audio with AI + +If your chat application supports text input, you can add audio input and outputβ€”just include audio in the modalities array and use an audio model, like gpt-4o-audio-preview. + +### Audio Input to Model + +```go +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-audio-preview", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("Please analyze this audio recording and summarize what was discussed."), + }, + { + Type: schemas.ContentBlockTypeInputAudio, + InputAudio: &schemas.InputAudioStruct{ + Data: <"base64-encoded audio data containing the word 'Affirmative'>", + Format: "wav", + }, + }, + }, + }, + }, + }, + }, +}) +``` + +## Text-to-Speech: Converting Text to Audio + +Convert text into natural-sounding speech using AI voice models. This example demonstrates generating an MP3 audio file from text using the "alloy" voice. The result is saved to a local file for playback. + +```go +response, err := client.SpeechRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "tts-1", // Using text-to-speech model + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: "Hello! This is a sample text that will be converted to speech using Bifrost's speech synthesis capabilities. The weather today is wonderful, and I hope you're having a great day!", + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: bifrost.Ptr("alloy"), + }, + ResponseFormat: "mp3", + }, + }, +}) + +if err != nil { + panic(err) +} + +// Handle speech synthesis response +if response.Speech != nil && len(response.Speech.Audio) > 0 { + // Save the audio to a file + filename := "output.mp3" + err := os.WriteFile("output.mp3", response.Speech.Audio, 0644) + if err != nil { + panic(fmt.Sprintf("Failed to save audio file: %v", err)) + } + + fmt.Printf("Speech synthesis successful! Audio saved to %s, file size: %d bytes\n", filename, len(response.Speech.Audio)) +} +``` + +## Speech-to-Text: Transcribing Audio Files + +Convert audio files into text using AI transcription models. This example shows how to transcribe an MP3 file using OpenAI's Whisper model, with an optional context prompt to improve accuracy. + +```go +// Read the audio file for transcription +audioFilename := "output.mp3" +audioData, err := os.ReadFile(audioFilename) +if err != nil { + panic(fmt.Sprintf("Failed to read audio file %s: %v. Please make sure the file exists.", audioFilename, err)) +} + +fmt.Printf("Loaded audio file %s (%d bytes) for transcription...\n", audioFilename, len(audioData)) + +response, err := client.TranscriptionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "whisper-1", // Using Whisper model for transcription + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Prompt: bifrost.Ptr("This is a sample audio transcription from Bifrost speech synthesis."), // Optional: provide context + }, + }, +}) + +if err != nil { + panic(err) +} + +fmt.Printf("Transcription Result: %s\n", response.Transcribe.Text) +``` + +## Advanced Vision Examples + +### Multiple Images + +Send multiple images in a single request for comparison or analysis. This is useful for comparing products, analyzing changes over time, or understanding relationships between different visual elements. + +```go +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("Compare these two images. What are the differences?"), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: "https://example.com/image1.jpg", + }, + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: "https://example.com/image2.jpg", + }, + }, + }, + }, + }, + }, + }, +}) +``` + +### Base64 Images + +Process local images by encoding them as base64 data URLs. This approach is ideal when you need to analyze images stored locally on your system without uploading them to external URLs first. + +```go +// Read and encode image +imageData, err := os.ReadFile("local_image.jpg") +if err != nil { + panic(err) +} +base64Image := base64.StdEncoding.EncodeToString(imageData) +dataURL := fmt.Sprintf("data:image/jpeg;base64,%s", base64Image) + +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("Analyze this image and describe what you see."), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: dataURL, + Detail: bifrost.Ptr("high"), + }, + }, + }, + }, + }, + }, + }, +}) +``` + +## Audio Configuration Options + +### Voice Selection for Speech Synthesis + +OpenAI provides six distinct voice options, each with different characteristics. This example generates sample audio files for each voice so you can compare and choose the one that best fits your application. + +```go +// Available voices: alloy, echo, fable, onyx, nova, shimmer +voices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"} + +for _, voice := range voices { + response, err := client.SpeechRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "tts-1", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: fmt.Sprintf("This is the %s voice speaking.", voice), + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: bifrost.Ptr(voice), + }, + ResponseFormat: "mp3", + }, + }, + }) + + if err == nil && response.Speech != nil { + filename := fmt.Sprintf("sample_%s.mp3", voice) + os.WriteFile(filename, response.Speech.Audio, 0644) + fmt.Printf("Generated %s\n", filename) + } +} +``` + +### Audio Formats + +Generate audio in different formats depending on your use case. MP3 for general use, Opus for web streaming, AAC for mobile apps, and FLAC for high-quality audio applications. + +```go +formats := []string{"mp3", "opus", "aac", "flac"} + +for _, format := range formats { + response, err := client.SpeechRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "tts-1", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: "Testing different audio formats.", + VoiceConfig: schemas.SpeechVoiceInput{Voice: bifrost.Ptr("alloy")}, + ResponseFormat: format, + }, + }, + }) + + if err == nil && response.Speech != nil { + filename := fmt.Sprintf("output.%s", format) + os.WriteFile(filename, response.Speech.Audio, 0644) + } +} +``` + +## Transcription Options + +### Language Specification + +Improve transcription accuracy by specifying the source language. This is particularly helpful for non-English audio or when the audio contains technical terms or specific domain vocabulary. + +```go +response, err := client.TranscriptionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "whisper-1", + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: bifrost.Ptr("es"), // Spanish + Prompt: bifrost.Ptr("This is a Spanish audio recording about technology."), + }, + }, +}) +``` + +### Response Formats + +Choose between simple text output or detailed JSON responses with timestamps. The verbose JSON format provides word-level and segment-level timing information, useful for creating subtitles or analyzing speech patterns. + +```go +// Text only +response, err := client.TranscriptionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "whisper-1", + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + ResponseFormat: bifrost.Ptr("text"), + }, + }, +}) + +// JSON with timestamps +response, err := client.TranscriptionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "whisper-1", + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + ResponseFormat: bifrost.Ptr("verbose_json"), + TimestampGranularities: &[]string{"word", "segment"}, + }, + }, +}) +``` + +## Provider Support + +Different providers support different multimodal capabilities: + +| Provider | Vision | Text-to-Speech | Speech-to-Text | +|----------|--------|----------------|----------------| +| OpenAI | βœ… GPT-4V, GPT-4o | βœ… TTS-1, TTS-1-HD | βœ… Whisper | +| Anthropic | βœ… Claude 3 Sonnet/Opus | ❌ | ❌ | +| Google Vertex | βœ… Gemini Pro Vision | βœ… | βœ… | +| Azure OpenAI | βœ… GPT-4V | βœ… | βœ… Whisper | + +## Next Steps + +- **[Streaming Responses](./streaming)** - Real-time multimodal processing +- **[Tool Calling](./tool-calling)** - Combine with external tools +- **[Provider Configuration](./provider-configuration)** - Multiple providers for different capabilities +- **[Core Features](../../features/)** - Advanced Bifrost capabilities diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx new file mode 100644 index 000000000..142ad8ec5 --- /dev/null +++ b/docs/quickstart/go-sdk/provider-configuration.mdx @@ -0,0 +1,401 @@ +--- +title: "Provider Configuration" +description: "Configure multiple AI providers for custom concurrency, queue sizes, proxy settings, and more." +icon: "sliders" +--- + +## Multi-Provider Setup + +Configure multiple providers to seamlessly switch between them. This example shows how to configure OpenAI, Anthropic, and Mistral providers. + +```go +type MyAccount struct{} + +func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Mistral}, nil +} + +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.OpenAI: + return []schemas.Key{{ + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }}, nil + case schemas.Anthropic: + return []schemas.Key{{ + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{}, + Weight: 1.0, + }}, nil + case schemas.Mistral: + return []schemas.Key{{ + Value: os.Getenv("MISTRAL_API_KEY"), + Models: []string{}, + Weight: 1.0, + }}, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} + +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + // Return same config for all providers + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} +``` + +> If Bifrost receives a new provider at runtime (i.e., one that is not returned by `GetConfiguredProviders()` initially on `bifrost.Init()`), it will set up the provider at runtime using `GetConfigForProvider()`, which may cause a delay in the first request to that provider. + +## Making Requests + +Once providers are configured, you can make requests to any specific provider. This example shows how to send a request directly to Mistral's latest vision model. Bifrost handles the provider-specific API formatting automatically. + +```go +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.Mistral, + Model: "pixtral-12b-latest", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, +}) +``` + +## Environment Variables + +Set up your API keys for the providers you want to use: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +export ANTHROPIC_API_KEY="your-anthropic-api-key" +export MISTRAL_API_KEY="your-mistral-api-key" +export GROQ_API_KEY="your-groq-api-key" +export COHERE_API_KEY="your-cohere-api-key" +``` + +## Advanced Configuration + +### Weighted Load Balancing + +Distribute requests across multiple API keys or providers based on custom weights. This example shows how to split traffic 70/30 between two OpenAI keys, useful for managing rate limits or costs across different accounts. + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.OpenAI: + return []schemas.Key{{ + Value: os.Getenv("OPENAI_API_KEY_1"), + Models: []string{}, + Weight: 0.7, // 70% of requests + }, + { + Value: os.Getenv("OPENAI_API_KEY_2"), + Models: []string{}, + Weight: 0.3, // 30% of requests + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +### Model-Specific Keys + +Use different API keys for specific models, allowing you to manage access controls and billing separately. This example uses a premium key for advanced reasoning models (o1-preview, o1-mini) and a standard key for regular GPT models. + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.OpenAI: + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o", "gpt-4o-mini"}, + Weight: 1.0, + }, + { + Value: os.Getenv("OPENAI_API_KEY_PREMIUM"), + Models: []string{"o1-preview", "o1-mini"}, + Weight: 1.0, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +### Custom Network Settings + +Customize the network configuration for each provider, including custom base URLs, extra headers, and timeout settings. This example shows how to use a local OpenAI-compatible server with custom headers for user identification. + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "http://localhost:8000/v1", // Custom openai setup + ExtraHeaders: map[string]string{ // Will be included in the request headers + "x-user-id": "123", + }, + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` +### Managing Retries + +Configure retry behavior for handling temporary failures and rate limits. This example sets up exponential backoff with up to 5 retries, starting with 1ms delay and capping at 10 seconds - ideal for handling transient network issues. + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + MaxRetries: 5, + RetryBackoffInitial: 1 * time.Millisecond, + RetryBackoffMax: 10 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +### Custom Concurrency and Buffer Size + +Fine-tune performance by adjusting worker concurrency and queue sizes per provider (defaults are 1000 workers and 5000 queue size). This example gives OpenAI higher limits (100 workers, 500 queue) for high throughput, while Anthropic gets conservative limits to respect their rate limits. + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + MaxConcurrency: 100, // Max number of concurrent requests (no of workers) + BufferSize: 500, // Max number of requests in the buffer (queue size) + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + MaxConcurrency: 25, + BufferSize: 100, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +### Setting Up a Proxy + +Route requests through proxies for compliance, security, or geographic requirements. This example shows both HTTP proxy for OpenAI and authenticated SOCKS5 proxy for Anthropic, useful for corporate environments or regional access. + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + ProxyConfig: &schemas.ProxyConfig{ + Type: schemas.HttpProxy, + URL: "http://localhost:8000", // Proxy URL + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + ProxyConfig: &schemas.ProxyConfig{ + Type: schemas.Socks5Proxy, + URL: "http://localhost:8000", // Proxy URL + Username: "user", + Password: "password", + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +### Send Back Raw Response + +Include the original provider response alongside Bifrost's standardized response format. Useful for debugging and accessing provider-specific metadata. + +```go +func (a *MyAccount) GetConfigForProvider(ctx *context.Context, provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + SendBackRawResponse: true, // Include raw provider response + }, nil +} +``` + +When enabled, the raw provider response appears in `ExtraFields.RawResponse`: + +```go +type BifrostResponse struct { + Choices []BifrostResponseChoice `json:"choices"` + Usage *LLMUsage `json:"usage,omitempty"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` +} + +type BifrostResponseExtraFields struct { + Provider ModelProvider `json:"provider"` + RawResponse interface{} `json:"raw_response,omitempty"` // Original provider response +} +``` + +## Provider-Specific Authentication + +Enterprise cloud providers require additional configuration beyond API keys. Configure Azure OpenAI, AWS Bedrock, and Google Vertex with platform-specific authentication details. + + + + + +Azure OpenAI requires endpoint URLs, deployment mappings, and API version configuration: + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o", "gpt-4o-mini"}, + Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), // e.g., "https://your-resource.openai.azure.com" + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, + APIVersion: bifrost.Ptr("2024-08-01-preview"), // Azure API version + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +AWS Bedrock supports both explicit credentials and IAM role authentication: + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"}, + Weight: 1.0, + Value: os.Getenv("AWS_API_KEY"), // Leave empty for IAM role authentication + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), // Leave empty for API Key authentication or system's IAM pickup + SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), // Leave empty for API Key authentication or system's IAM pickup + SessionToken: bifrost.Ptr(os.Getenv("AWS_SESSION_TOKEN")), // Optional + Region: bifrost.Ptr("us-east-1"), + // For model profiles (inference profiles) + Deployments: map[string]string{ + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0", + }, + // For direct model access without profiles + ARN: bifrost.Ptr("arn:aws:bedrock:us-east-1:123456789012:inference-profile"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +**Notes:** +- If using API Key authentication, set `Value` field to the API key, else leave it empty for IAM role authentication. +- In IAM role authentication, if both `AccessKey` and `SecretKey` are empty, Bifrost uses IAM from the environment. +- `ARN` is required for URL formation - `Deployments` mapping is ignored without it. +- When using `ARN` + `Deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. + + + + + +Google Vertex requires project configuration and authentication credentials: + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: os.Getenv("VERTEX_API_KEY"), // Optional if using service account + Models: []string{"gemini-pro", "gemini-pro-vision"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), // GCP project ID + Region: "us-central1", // GCP region + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), // Service account JSON or path + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +## Best Practices + +### Performance Considerations + +Keys are fetched from your `GetKeysForProvider` implementation on every request. Ensure your implementation is optimized for speed to avoid adding latency: + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + // βœ… Good: Fast in-memory lookup + switch provider { + case schemas.OpenAI: + return a.cachedOpenAIKeys, nil // Pre-cached keys + } + + // ❌ Avoid: Database queries, API calls, complex algorithms + // This will add latency to every AI request + // keys := fetchKeysFromDatabase(provider) // Too slow! + // return processWithComplexLogic(keys) // Too slow! + + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +**Recommendations:** +- Cache keys in memory during application startup +- Use simple switch statements or map lookups +- Avoid database queries, file I/O, or network calls +- Keep complex key processing logic outside the request path + +## Next Steps + +- **[Streaming Responses](./streaming)** - Real-time response generation +- **[Tool Calling](./tool-calling)** - Enable AI to use external functions +- **[Multimodal AI](./multimodal)** - Process images, audio, and text +- **[Core Features](../../features/)** - Advanced Bifrost capabilities diff --git a/docs/quickstart/go-sdk/setting-up.mdx b/docs/quickstart/go-sdk/setting-up.mdx new file mode 100644 index 000000000..be6d9165c --- /dev/null +++ b/docs/quickstart/go-sdk/setting-up.mdx @@ -0,0 +1,146 @@ +--- +title: "Setting Up" +description: "Get Bifrost running in your Go application in 30 seconds with minimal setup and direct code integration." +icon: "play" +--- + + + + +## 30-Second Setup + +Get Bifrost running in your Go application with minimal setup. This guide shows you how to integrate multiple AI providers through a single, unified interface. + +### 1. Install Package + +```bash +go mod init my-bifrost-app +go get github.com/maximhq/bifrost/core +``` + +### 2. Set Environment Variable + +```bash +export OPENAI_API_KEY="your-openai-api-key" +``` + +### 3. Create `main.go` + +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +type MyAccount struct{} + +// Account interface needs to implement these 3 methods +func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + if provider == schemas.OpenAI { + return []schemas.Key{{ + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, // Keep Models empty to use any model + Weight: 1.0, + }}, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} + +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + if provider == schemas.OpenAI { + // Return default config (can be customized for advanced use cases) + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} + +// Main function implement to initialize bifrost and make a request +func main() { + client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &MyAccount{}, + }) + if initErr != nil { + panic(initErr) + } + defer client.Cleanup() + + messages := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, Bifrost!"), + }, + }, + } + + response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + }) + + if err != nil { + panic(err) + } + + fmt.Println("Response:", *response.Choices[0].Message.Content.ContentStr) +} +``` + +### 4. Run Your App + +```bash +go run main.go +# Output: Response: Hello! I'm Bifrost, your AI model gateway... +``` + +**πŸŽ‰ That's it!** You're now running Bifrost in your Go application. + +### What Just Happened? + +1. **Account Interface**: `MyAccount` provides API keys and list of providers to Bifrost for initialisation and key lookups. +2. **Provider Resolution**: `schemas.OpenAI` tells Bifrost to use OpenAI as the provider. +3. **Model Selection**: `"gpt-4o-mini"` specifies which model to use. +4. **Unified API**: Same interface works for any provider/model combination (OpenAI, Anthropic, Vertex etc.) + +--- + +## Next Steps + +Now that you have Bifrost running, explore these focused guides: + +### Essential Topics + +- **[Provider Configuration](./provider-configuration)** - Multiple providers & automatic failovers +- **[Streaming Responses](./streaming)** - Real-time chat, audio, and transcription +- **[Tool Calling](./tool-calling)** - Functions & MCP server integration +- **[Multimodal AI](./multimodal)** - Images, speech synthesis, and vision + +### Advanced Topics + +- **[Core Features](../../features/)** - Caching, observability, and governance +- **[Integrations](../../integrations/)** - Drop-in replacements for existing SDKs +- **[Architecture](../../architecture/)** - How Bifrost works internally +- **[Deployment](../../deployment/)** - Production setup and scaling + +--- + +**Happy coding with Bifrost!** πŸš€ diff --git a/docs/quickstart/go-sdk/streaming.mdx b/docs/quickstart/go-sdk/streaming.mdx new file mode 100644 index 000000000..a50071b21 --- /dev/null +++ b/docs/quickstart/go-sdk/streaming.mdx @@ -0,0 +1,218 @@ +--- +title: "Streaming Responses" +description: "Receive AI responses in real-time as they're generated. Perfect for chat applications, audio processing, and real-time transcription where you want immediate results." +icon: "water" +--- + +## Streaming Chat Responses + +Receive AI responses in real-time as they're generated. Perfect for chat applications where you want to show responses as they're being typed, improving user experience. + +```go +stream, err := client.ChatCompletionStreamRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, +}) + +if err != nil { + log.Printf("Streaming request failed: %v", err) + return +} + +for chunk := range stream { + // Handle errors in stream + if chunk.BifrostError != nil { + log.Printf("Stream error: %v", chunk.BifrostError) + break + } + + // Process response chunks + if chunk.BifrostResponse != nil && len(chunk.BifrostResponse.Choices) > 0 { + choice := chunk.BifrostResponse.Choices[0] + + // Check for streaming content + if choice.BifrostStreamResponseChoice != nil && + choice.BifrostStreamResponseChoice.Delta.Content != nil { + + content := *choice.BifrostStreamResponseChoice.Delta.Content + fmt.Print(content) // Print content as it arrives + } + } +} +``` + +> **Note:** Streaming requests also follow the default timeout setting defined in provider configuration, which defaults to **30 seconds**. + + +Bifrost standardizes all stream responses to send usage and finish reason only in the last chunk, and content in the previous chunks. + + +## Text-to-Speech Streaming: Real-time Audio Generation + +Stream audio generation in real-time as text is converted to speech. Ideal for long texts or when you need immediate audio playback. + +```go +stream, err := client.SpeechStreamRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "tts-1", // Using text-to-speech model + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: "Hello! This is a sample text that will be converted to speech using Bifrost's speech synthesis capabilities. The weather today is wonderful, and I hope you're having a great day!", + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: bifrost.Ptr("alloy"), + }, + ResponseFormat: "mp3", + }, + }, +}) + +if err != nil { + panic(err) +} + +// Handle speech synthesis stream +var audioData []byte +var totalChunks int +filename := "output.mp3" + +for chunk := range stream { + if chunk.BifrostError != nil { + panic(fmt.Sprintf("Stream error: %s", chunk.BifrostError.Error.Message)) + } + + if chunk.BifrostResponse != nil && chunk.BifrostResponse.Speech != nil { + // Accumulate audio data from each chunk + audioData = append(audioData, chunk.BifrostResponse.Speech.Audio...) + totalChunks++ + fmt.Printf("Received chunk %d, size: %d bytes\n", totalChunks, len(chunk.BifrostResponse.Speech.Audio)) + } +} + +if len(audioData) > 0 { + // Save the accumulated audio to a file + err := os.WriteFile(filename, audioData, 0644) + if err != nil { + panic(fmt.Sprintf("Failed to save audio file: %v", err)) + } + + fmt.Printf("Speech synthesis streaming complete! Audio saved to %s\n", filename) + fmt.Printf("Total chunks received: %d, final file size: %d bytes\n", totalChunks, len(audioData)) +} +``` + +## Speech-to-Text Streaming: Real-time Audio Transcription + +Stream audio transcription results as they're processed. Get immediate text output for real-time applications or long audio files. + +```go +// Read the audio file for transcription +audioFilename := "output.mp3" +audioData, err := os.ReadFile(audioFilename) +if err != nil { + panic(fmt.Sprintf("Failed to read audio file %s: %v. Please make sure the file exists.", audioFilename, err)) +} + +fmt.Printf("Loaded audio file %s (%d bytes) for transcription...\n", audioFilename, len(audioData)) + +stream, err := client.TranscriptionStreamRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "whisper-1", // Using Whisper model for transcription + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Prompt: bifrost.Ptr("This is a sample audio transcription from Bifrost speech synthesis."), // Optional: provide context + }, + }, +}) + +if err != nil { + panic(err) +} + +for chunk := range stream { + if chunk.BifrostError != nil { + panic(fmt.Sprintf("Stream error: %s", chunk.BifrostError.Error.Message)) + } + + if chunk.BifrostResponse != nil && chunk.BifrostResponse.Transcribe != nil { + // Print each chunk of text as it arrives + fmt.Print(chunk.BifrostResponse.Transcribe.Text) + } +} +``` + +## Streaming Best Practices + +### Buffering for Audio + +For audio streaming, consider buffering chunks before saving: + +```go +const bufferSize = 1024 * 1024 // 1MB buffer + +var audioBuffer bytes.Buffer +var lastSave time.Time + +for chunk := range stream { + if chunk.BifrostResponse != nil && chunk.BifrostResponse.Speech != nil { + audioBuffer.Write(chunk.BifrostResponse.Speech.Audio) + + // Save every second or when buffer is full + if time.Since(lastSave) > time.Second || audioBuffer.Len() > bufferSize { + // Append to file + file, err := os.OpenFile("streaming_audio.mp3", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + file.Write(audioBuffer.Bytes()) + file.Close() + audioBuffer.Reset() + lastSave = time.Now() + } + } + } +} +``` + +### Context and Cancellation + +Use context to control streaming duration: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +defer cancel() + +stream, err := client.ChatCompletionStreamRequest(ctx, &schemas.BifrostRequest{ + // ... your request +}) + +// Stream will automatically stop after 30 seconds +``` + +## Voice Options + +OpenAI TTS supports these voices: + +- `alloy` - Balanced, natural voice +- `echo` - Deep, resonant voice +- `fable` - Expressive, storytelling voice +- `onyx` - Strong, confident voice +- `nova` - Bright, energetic voice +- `shimmer` - Gentle, soothing voice + +```go +// Different voice example +VoiceConfig: schemas.SpeechVoiceInput{ + Voice: bifrost.Ptr("nova"), +}, +``` + +> **Note:** Please check each model's documentation to see if it supports the corresponding streaming features. Not all providers support all streaming capabilities. + +## Next Steps + +- **[Tool Calling](./tool-calling)** - Enable AI to use external functions +- **[Multimodal AI](./multimodal)** - Process images and multimedia content +- **[Provider Configuration](./provider-configuration)** - Multiple providers for redundancy +- **[Core Features](../../features/)** - Advanced Bifrost capabilities diff --git a/docs/quickstart/go-sdk/tool-calling.mdx b/docs/quickstart/go-sdk/tool-calling.mdx new file mode 100644 index 000000000..f58549981 --- /dev/null +++ b/docs/quickstart/go-sdk/tool-calling.mdx @@ -0,0 +1,271 @@ +--- +title: "Tool Calling" +description: "Enable AI models to use external functions and services by defining tool schemas or connecting to Model Context Protocol (MCP) servers. This allows AI to interact with databases, APIs, file systems, and more." +icon: "wrench" +--- + +## Function Calling with Custom Tools + +Enable AI models to use external functions by defining tool schemas. Models can then call these functions automatically based on user requests. + +```go +// Define a tool for the calculator +calculatorTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "calculator", + Description: "A calculator tool", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "operation": map[string]interface{}{ + "type": "string", + "description": "The operation to perform", + "enum": []string{"add", "subtract", "multiply", "divide"}, + }, + "a": map[string]interface{}{ + "type": "number", + "description": "The first number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "The second number", + }, + }, + Required: []string{"operation", "a", "b"}, + }, + }, +} + +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What is 2+2? Use the calculator tool."), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Tools: &[]schemas.Tool{calculatorTool}, + }, +}) + +if err != nil { + panic(err) +} + +toolCalls := response.Choices[0].Message.AssistantMessage.ToolCalls +if toolCalls != nil { + for _, toolCall := range *toolCalls { + fmt.Printf("Tool call in response - %s: %s\n", *toolCall.ID, *toolCall.Function.Name) + fmt.Printf("Tool call arguments - %s\n", toolCall.Function.Arguments) + } +} +``` + +## Connecting to MCP Servers + +Connect to Model Context Protocol (MCP) servers to give AI models access to external tools and services without manually defining each function. + +```go +client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: &MyAccount{}, + MCPConfig: &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + // Sample youtube-mcp server + { + Name: "youtube-mcp", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: bifrost.Ptr("http://your-youtube-mcp-url"), + }, + }, + }, +}) +if initErr != nil { + panic(initErr) +} +defer client.Cleanup() + +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What do you see when you search for 'bifrost' on youtube?"), + }, + }, + }, + }, +}) + +if err != nil { + panic(err) +} + +toolCalls := response.Choices[0].Message.AssistantMessage.ToolCalls +if toolCalls != nil { + for _, toolCall := range *toolCalls { + fmt.Printf("Tool call in response - %s: %s\n", *toolCall.ID, *toolCall.Function.Name) + fmt.Printf("Tool call arguments - %s\n", toolCall.Function.Arguments) + } +} +``` + +Read more about MCP connections and in-house tool registration via local MCP server in the [MCP Features](../../features/mcp) section. + +## Advanced Tool Examples + +### Weather API Tool + +```go +weatherTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_weather", + Description: "Get the current weather for a location", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "description": "Temperature unit", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, +} +``` + +### Database Query Tool + +```go +databaseTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "query_database", + Description: "Execute a SQL query on the customer database", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "The SQL query to execute", + }, + "table": map[string]interface{}{ + "type": "string", + "description": "The table to query", + "enum": []string{"customers", "orders", "products"}, + }, + }, + Required: []string{"query", "table"}, + }, + }, +} +``` + +### File System Tool + +```go +fileSystemTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "read_file", + Description: "Read the contents of a file", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "path": map[string]interface{}{ + "type": "string", + "description": "The file path to read", + }, + "encoding": map[string]interface{}{ + "type": "string", + "description": "File encoding", + "enum": []string{"utf-8", "ascii", "base64"}, + "default": "utf-8", + }, + }, + Required: []string{"path"}, + }, + }, +} +``` + +## Multiple Tool Support + +Use multiple tools in a single request: + +```go +response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What's the weather in New York and calculate 15% tip for a $50 bill?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Tools: &[]schemas.Tool{weatherTool, calculatorTool}, + ToolChoice: bifrost.Ptr("auto"), // Let AI decide which tools to use + }, +}) +``` + +## Tool Choice Options + +Control how the AI uses tools: + +```go +// Force use of a specific tool +Params: &schemas.ModelParameters{ + Tools: &[]schemas.Tool{calculatorTool}, + ToolChoiceStruct: &schemas.ToolChoiceStruct{ + Type: "function", + Function: schemas.ToolChoiceFunction{ + Name: "calculator", + }, + }, +} + +// Let AI decide automatically +Params: &schemas.ModelParameters{ + Tools: &[]schemas.Tool{calculatorTool, weatherTool}, + ToolChoice: &schemas.ToolChoice{ + ToolChoiceStr: bifrost.Ptr("auto"), + }, +} + +// Disable tool usage +Params: &schemas.ModelParameters{ + Tools: &[]schemas.Tool{calculatorTool}, + ToolChoice: &schemas.ToolChoice{ + ToolChoiceStr: bifrost.Ptr("none"), + }, +} +``` + +## Next Steps + +- **[Multimodal AI](./multimodal)** - Process images, audio, and multimedia content +- **[Streaming Responses](./streaming)** - Real-time response generation +- **[Provider Configuration](./provider-configuration)** - Multiple providers for redundancy +- **[MCP Features](../../features/mcp)** - Advanced MCP server management diff --git a/docs/style.css b/docs/style.css new file mode 100644 index 000000000..e63a15fe7 --- /dev/null +++ b/docs/style.css @@ -0,0 +1,3 @@ +.nav-logo { + height: 2.75rem; +} \ No newline at end of file diff --git a/framework/changelog.md b/framework/changelog.md new file mode 100644 index 000000000..57e331e11 --- /dev/null +++ b/framework/changelog.md @@ -0,0 +1,4 @@ + + + +- Pricing module now accommodates nested model names i.e. groq/openai/gpt-oss-20b was getting skipped while computing costs. \ No newline at end of file diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go new file mode 100644 index 000000000..9918b3aac --- /dev/null +++ b/framework/configstore/clientconfig.go @@ -0,0 +1,60 @@ +package configstore + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +type EnvKeyType string + +const ( + EnvKeyTypeAPIKey EnvKeyType = "api_key" + EnvKeyTypeAzureConfig EnvKeyType = "azure_config" + EnvKeyTypeVertexConfig EnvKeyType = "vertex_config" + EnvKeyTypeBedrockConfig EnvKeyType = "bedrock_config" + EnvKeyTypeConnection EnvKeyType = "connection_string" +) + +// EnvKeyInfo stores information about a key sourced from environment +type EnvKeyInfo struct { + EnvVar string // The environment variable name (without env. prefix) + Provider schemas.ModelProvider // The provider this key belongs to (empty for core/mcp configs) + KeyType EnvKeyType // Type of key (e.g., "api_key", "azure_config", "vertex_config", "bedrock_config", "connection_string") + ConfigPath string // Path in config where this env var is used + KeyID string // The key ID this env var belongs to (empty for non-key configs like bedrock_config, connection_string) +} + +// ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. +// It includes settings for excess request handling, Prometheus metrics, and initial pool size. +type ClientConfig struct { + DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full + InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client + PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics + EnableLogging bool `json:"enable_logging"` // Enable logging of requests and responses + EnableGovernance bool `json:"enable_governance"` // Enable governance on all requests + EnforceGovernanceHeader bool `json:"enforce_governance_header"` // Enforce governance on all requests + AllowDirectKeys bool `json:"allow_direct_keys"` // Allow direct keys to be used for requests + AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) + MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB +} + +// ProviderConfig represents the configuration for a specific AI model provider. +// It includes API keys, network settings, and concurrency settings. +type ProviderConfig struct { + Keys []schemas.Key `json:"keys"` // API keys for the provider with UUIDs + NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration +} + +// ConfigMap maps provider names to their configurations. +type ConfigMap map[schemas.ModelProvider]ProviderConfig + +type GovernanceConfig struct { + VirtualKeys []TableVirtualKey `json:"virtual_keys"` + Teams []TableTeam `json:"teams"` + Customers []TableCustomer `json:"customers"` + Budgets []TableBudget `json:"budgets"` + RateLimits []TableRateLimit `json:"rate_limits"` +} diff --git a/framework/configstore/config.go b/framework/configstore/config.go new file mode 100644 index 000000000..642d81c8a --- /dev/null +++ b/framework/configstore/config.go @@ -0,0 +1,60 @@ +package configstore + +import ( + "encoding/json" + "fmt" +) + +// ConfigStoreType represents the type of config store. +type ConfigStoreType string + +// ConfigStoreTypeSQLite is the type of config store for SQLite. +const ( + ConfigStoreTypeSQLite ConfigStoreType = "sqlite" +) + +// Config represents the configuration for the config store. +type Config struct { + Enabled bool `json:"enabled"` + Type ConfigStoreType `json:"type"` + Config any `json:"config"` +} + +// UnmarshalJSON unmarshals the config from JSON. +func (c *Config) UnmarshalJSON(data []byte) error { + // First, unmarshal into a temporary struct to get the basic fields + type TempConfig struct { + Enabled bool `json:"enabled"` + Type ConfigStoreType `json:"type"` + Config json.RawMessage `json:"config"` // Keep as raw JSON + } + + var temp TempConfig + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal config store config: %w", err) + } + + // Set basic fields + c.Enabled = temp.Enabled + c.Type = temp.Type + + if !temp.Enabled { + c.Config = nil + return nil + } + + // Parse the config field based on type + switch temp.Type { + case ConfigStoreTypeSQLite: + var sqliteConfig SQLiteConfig + if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil { + return fmt.Errorf("failed to unmarshal sqlite config: %w", err) + } + c.Config = &sqliteConfig + + default: + return fmt.Errorf("unknown config store type: %s", temp.Type) + } + + return nil +} diff --git a/framework/configstore/errors.go b/framework/configstore/errors.go new file mode 100644 index 000000000..e5b77064d --- /dev/null +++ b/framework/configstore/errors.go @@ -0,0 +1,5 @@ +package configstore + +import "errors" + +var ErrNotFound = errors.New("not found") diff --git a/framework/configstore/internal/migration/migrator.go b/framework/configstore/internal/migration/migrator.go new file mode 100644 index 000000000..95d0ea806 --- /dev/null +++ b/framework/configstore/internal/migration/migrator.go @@ -0,0 +1,512 @@ +// Portions of this file are derived from https://github.com/go-gormigrate/gormigrate +// MIT License +// Copyright (c) 2016 Andrey Nering +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + +package migration + +import ( + "context" + "errors" + "fmt" + "reflect" + + "gorm.io/gorm" +) + +const ( + initSchemaMigrationID = "SCHEMA_INIT" +) + +// MigrateFunc is the func signature for migrating. +type MigrateFunc func(*gorm.DB) error + +// RollbackFunc is the func signature for rollbacking. +type RollbackFunc func(*gorm.DB) error + +// InitSchemaFunc is the func signature for initializing the schema. +type InitSchemaFunc func(*gorm.DB) error + +// Options define options for all migrations. +type Options struct { + // TableName is the migration table. + TableName string + // IDColumnName is the name of column where the migration id will be stored. + IDColumnName string + // IDColumnSize is the length of the migration id column + IDColumnSize int + // UseTransaction makes Gormigrate execute migrations inside a single transaction. + // Keep in mind that not all databases support DDL commands inside transactions. + UseTransaction bool + // ValidateUnknownMigrations will cause migrate to fail if there's unknown migration + // IDs in the database + ValidateUnknownMigrations bool +} + +// Migration represents a database migration (a modification to be made on the database). +type Migration struct { + // ID is the migration identifier. Usually a timestamp like "201601021504". + ID string + // Migrate is a function that will br executed while running this migration. + Migrate MigrateFunc + // Rollback will be executed on rollback. Can be nil. + Rollback RollbackFunc +} + +// Gormigrate represents a collection of all migrations of a database schema. +type Gormigrate struct { + db *gorm.DB + tx *gorm.DB + options *Options + migrations []*Migration + initSchema InitSchemaFunc +} + +// ReservedIDError is returned when a migration is using a reserved ID +type ReservedIDError struct { + ID string +} + +func (e *ReservedIDError) Error() string { + return fmt.Sprintf(`gormigrate: Reserved migration ID: "%s"`, e.ID) +} + +// DuplicatedIDError is returned when more than one migration have the same ID +type DuplicatedIDError struct { + ID string +} + +func (e *DuplicatedIDError) Error() string { + return fmt.Sprintf(`gormigrate: Duplicated migration ID: "%s"`, e.ID) +} + +var ( + // DefaultOptions can be used if you don't want to think about options. + DefaultOptions = &Options{ + TableName: "migrations", + IDColumnName: "id", + IDColumnSize: 255, + UseTransaction: false, + ValidateUnknownMigrations: false, + } + + // ErrRollbackImpossible is returned when trying to rollback a migration + // that has no rollback function. + ErrRollbackImpossible = errors.New("gormigrate: It's impossible to rollback this migration") + + // ErrNoMigrationDefined is returned when no migration is defined. + ErrNoMigrationDefined = errors.New("gormigrate: No migration defined") + + // ErrMissingID is returned when the ID od migration is equal to "" + ErrMissingID = errors.New("gormigrate: Missing ID in migration") + + // ErrNoRunMigration is returned when any run migration was found while + // running RollbackLast + ErrNoRunMigration = errors.New("gormigrate: Could not find last run migration") + + // ErrMigrationIDDoesNotExist is returned when migrating or rolling back to a migration ID that + // does not exist in the list of migrations + ErrMigrationIDDoesNotExist = errors.New("gormigrate: Tried to migrate to an ID that doesn't exist") + + // ErrUnknownPastMigration is returned if a migration exists in the DB that doesn't exist in the code + ErrUnknownPastMigration = errors.New("gormigrate: Found migration in DB that does not exist in code") +) + +// New returns a new Gormigrate. +func New(db *gorm.DB, options *Options, migrations []*Migration) *Gormigrate { + if options == nil { + options = DefaultOptions + } + if options.TableName == "" { + options.TableName = DefaultOptions.TableName + } + if options.IDColumnName == "" { + options.IDColumnName = DefaultOptions.IDColumnName + } + if options.IDColumnSize == 0 { + options.IDColumnSize = DefaultOptions.IDColumnSize + } + return &Gormigrate{ + db: db, + options: options, + migrations: migrations, + } +} + +// InitSchema sets a function that is run if no migration is found. +// The idea is preventing to run all migrations when a new clean database +// is being migrating. In this function you should create all tables and +// foreign key necessary to your application. +func (g *Gormigrate) InitSchema(initSchema InitSchemaFunc) { + g.initSchema = initSchema +} + +// Migrate executes all migrations that did not run yet. +func (g *Gormigrate) Migrate() error { + if !g.hasMigrations() { + return ErrNoMigrationDefined + } + var targetMigrationID string + if len(g.migrations) > 0 { + targetMigrationID = g.migrations[len(g.migrations)-1].ID + } + return g.migrate(targetMigrationID) +} + +// MigrateTo executes all migrations that did not run yet up to the migration that matches `migrationID`. +func (g *Gormigrate) MigrateTo(migrationID string) error { + if err := g.checkIDExist(migrationID); err != nil { + return err + } + return g.migrate(migrationID) +} + +func (g *Gormigrate) migrate(migrationID string) error { + if !g.hasMigrations() { + return ErrNoMigrationDefined + } + + if err := g.checkReservedID(); err != nil { + return err + } + + if err := g.checkDuplicatedID(); err != nil { + return err + } + + g.begin() + defer g.rollback() + + if err := g.createMigrationTableIfNotExists(); err != nil { + return err + } + + if g.options.ValidateUnknownMigrations { + unknownMigrations, err := g.unknownMigrationsHaveHappened() + if err != nil { + return err + } + if unknownMigrations { + return ErrUnknownPastMigration + } + } + + if g.initSchema != nil { + canInitializeSchema, err := g.canInitializeSchema() + if err != nil { + return err + } + if canInitializeSchema { + if err := g.runInitSchema(); err != nil { + return err + } + return g.commit() + } + } + + for _, migration := range g.migrations { + if err := g.runMigration(migration); err != nil { + return err + } + if migrationID != "" && migration.ID == migrationID { + break + } + } + return g.commit() +} + +// There are migrations to apply if either there's a defined +// initSchema function or if the list of migrations is not empty. +func (g *Gormigrate) hasMigrations() bool { + return g.initSchema != nil || len(g.migrations) > 0 +} + +// Check whether any migration is using a reserved ID. +// For now there's only have one reserved ID, but there may be more in the future. +func (g *Gormigrate) checkReservedID() error { + for _, m := range g.migrations { + if m.ID == initSchemaMigrationID { + return &ReservedIDError{ID: m.ID} + } + } + return nil +} + +func (g *Gormigrate) checkDuplicatedID() error { + lookup := make(map[string]struct{}, len(g.migrations)) + for _, m := range g.migrations { + if _, ok := lookup[m.ID]; ok { + return &DuplicatedIDError{ID: m.ID} + } + lookup[m.ID] = struct{}{} + } + return nil +} + +func (g *Gormigrate) checkIDExist(migrationID string) error { + for _, migrate := range g.migrations { + if migrate.ID == migrationID { + return nil + } + } + return ErrMigrationIDDoesNotExist +} + +// RollbackLast undo the last migration +func (g *Gormigrate) RollbackLast() error { + if len(g.migrations) == 0 { + return ErrNoMigrationDefined + } + + g.begin() + defer g.rollback() + + lastRunMigration, err := g.getLastRunMigration() + if err != nil { + return err + } + + if err := g.rollbackMigration(lastRunMigration); err != nil { + return err + } + return g.commit() +} + +// RollbackTo undoes migrations up to the given migration that matches the `migrationID`. +// Migration with the matching `migrationID` is not rolled back. +func (g *Gormigrate) RollbackTo(migrationID string) error { + if len(g.migrations) == 0 { + return ErrNoMigrationDefined + } + + if err := g.checkIDExist(migrationID); err != nil { + return err + } + + g.begin() + defer g.rollback() + + for i := len(g.migrations) - 1; i >= 0; i-- { + migration := g.migrations[i] + if migration.ID == migrationID { + break + } + migrationRan, err := g.migrationRan(migration) + if err != nil { + return err + } + if migrationRan { + if err := g.rollbackMigration(migration); err != nil { + return err + } + } + } + return g.commit() +} + +func (g *Gormigrate) getLastRunMigration() (*Migration, error) { + for i := len(g.migrations) - 1; i >= 0; i-- { + migration := g.migrations[i] + + migrationRan, err := g.migrationRan(migration) + if err != nil { + return nil, err + } + + if migrationRan { + return migration, nil + } + } + return nil, ErrNoRunMigration +} + +// RollbackMigration undo a migration. +func (g *Gormigrate) RollbackMigration(m *Migration) error { + g.begin() + defer g.rollback() + + if err := g.rollbackMigration(m); err != nil { + return err + } + return g.commit() +} + +func (g *Gormigrate) rollbackMigration(m *Migration) error { + if m.Rollback == nil { + return ErrRollbackImpossible + } + + if err := m.Rollback(g.tx); err != nil { + return err + } + + cond := fmt.Sprintf("%s = ?", g.options.IDColumnName) + return g.tx.Table(g.options.TableName).Where(cond, m.ID).Delete(g.model()).Error +} + +func (g *Gormigrate) runInitSchema() error { + if err := g.initSchema(g.tx); err != nil { + return err + } + if err := g.insertMigration(initSchemaMigrationID); err != nil { + return err + } + + for _, migration := range g.migrations { + if err := g.insertMigration(migration.ID); err != nil { + return err + } + } + + return nil +} + +func (g *Gormigrate) runMigration(migration *Migration) error { + if len(migration.ID) == 0 { + return ErrMissingID + } + + migrationRan, err := g.migrationRan(migration) + if err != nil { + return err + } + if !migrationRan { + if err := migration.Migrate(g.tx); err != nil { + return err + } + + if err := g.insertMigration(migration.ID); err != nil { + return err + } + } + return nil +} + +// model returns pointer to dynamically created gorm migration model struct value +// +// struct defined as { +// ID string `gorm:"primaryKey;column:;size:"` +// } +func (g *Gormigrate) model() any { + f := reflect.StructField{ + Name: reflect.ValueOf("ID").Interface().(string), + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(fmt.Sprintf( + `gorm:"primaryKey;column:%s;size:%d"`, + g.options.IDColumnName, + g.options.IDColumnSize, + )), + } + structType := reflect.StructOf([]reflect.StructField{f}) + structValue := reflect.New(structType).Elem() + return structValue.Addr().Interface() +} + +func (g *Gormigrate) createMigrationTableIfNotExists() error { + if g.tx.Migrator().HasTable(g.options.TableName) { + return nil + } + return g.tx.Table(g.options.TableName).AutoMigrate(g.model()) +} + +func (g *Gormigrate) migrationRan(m *Migration) (bool, error) { + var count int64 + err := g.tx. + Table(g.options.TableName). + Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), m.ID). + Count(&count). + Error + return count > 0, err +} + +// The schema can be initialised only if it hasn't been initialised yet +// and no other migration has been applied already. +func (g *Gormigrate) canInitializeSchema() (bool, error) { + migrationRan, err := g.migrationRan(&Migration{ID: initSchemaMigrationID}) + if err != nil { + return false, err + } + if migrationRan { + return false, nil + } + + // If the ID doesn't exist, we also want the list of migrations to be empty + var count int64 + err = g.tx. + Table(g.options.TableName). + Count(&count). + Error + return count == 0, err +} + +func (g *Gormigrate) unknownMigrationsHaveHappened() (bool, error) { + rows, err := g.tx.Table(g.options.TableName).Select(g.options.IDColumnName).Rows() + if err != nil { + return false, err + } + defer func() { + if err := rows.Close(); err != nil { + g.tx.Logger.Error(context.TODO(), err.Error()) + } + }() + + validIDSet := make(map[string]struct{}, len(g.migrations)+1) + validIDSet[initSchemaMigrationID] = struct{}{} + for _, migration := range g.migrations { + validIDSet[migration.ID] = struct{}{} + } + + for rows.Next() { + var pastMigrationID string + if err := rows.Scan(&pastMigrationID); err != nil { + return false, err + } + if _, ok := validIDSet[pastMigrationID]; !ok { + return true, nil + } + } + + return false, nil +} + +func (g *Gormigrate) insertMigration(id string) error { + record := g.model() + reflect.ValueOf(record).Elem().FieldByName("ID").SetString(id) + return g.tx.Table(g.options.TableName).Create(record).Error +} + +func (g *Gormigrate) begin() { + if g.options.UseTransaction { + g.tx = g.db.Begin() + } else { + g.tx = g.db + } +} + +func (g *Gormigrate) commit() error { + if g.options.UseTransaction { + return g.tx.Commit().Error + } + return nil +} + +func (g *Gormigrate) rollback() { + if g.options.UseTransaction { + g.tx.Rollback() + } +} diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go new file mode 100644 index 000000000..6953a6697 --- /dev/null +++ b/framework/configstore/migrations.go @@ -0,0 +1,244 @@ +package configstore + +import ( + "fmt" + + "github.com/maximhq/bifrost/framework/configstore/internal/migration" + "gorm.io/gorm" +) + +// Migrate performs the necessary database migrations. +func triggerMigrations(db *gorm.DB) error { + if err := migrationInit(db); err != nil { + return err + } + if err := migrationMany2ManyJoinTable(db); err != nil { + return err + } + if err := migrationAddCustomProviderConfigJSONColumn(db); err != nil { + return err + } + return nil +} + +// migrationInit is the first migration +func migrationInit(db *gorm.DB) error { + m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{ + ID: "init", + Migrate: func(tx *gorm.DB) error { + migrator := tx.Migrator() + if !migrator.HasTable(&TableConfigHash{}) { + if err := migrator.CreateTable(&TableConfigHash{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableProvider{}) { + if err := migrator.CreateTable(&TableProvider{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableKey{}) { + if err := migrator.CreateTable(&TableKey{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableModel{}) { + if err := migrator.CreateTable(&TableModel{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableMCPClient{}) { + if err := migrator.CreateTable(&TableMCPClient{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableClientConfig{}) { + if err := migrator.CreateTable(&TableClientConfig{}); err != nil { + return err + } + } else if !migrator.HasColumn(&TableClientConfig{}, "max_request_body_size_mb") { + if err := migrator.AddColumn(&TableClientConfig{}, "max_request_body_size_mb"); err != nil { + return err + } + } + if !migrator.HasTable(&TableEnvKey{}) { + if err := migrator.CreateTable(&TableEnvKey{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableVectorStoreConfig{}) { + if err := migrator.CreateTable(&TableVectorStoreConfig{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableLogStoreConfig{}) { + if err := migrator.CreateTable(&TableLogStoreConfig{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableBudget{}) { + if err := migrator.CreateTable(&TableBudget{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableRateLimit{}) { + if err := migrator.CreateTable(&TableRateLimit{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableCustomer{}) { + if err := migrator.CreateTable(&TableCustomer{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableTeam{}) { + if err := migrator.CreateTable(&TableTeam{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableVirtualKey{}) { + if err := migrator.CreateTable(&TableVirtualKey{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableConfig{}) { + if err := migrator.CreateTable(&TableConfig{}); err != nil { + return err + } + } + if !migrator.HasTable(&TableModelPricing{}) { + if err := migrator.CreateTable(&TableModelPricing{}); err != nil { + return err + } + } + if !migrator.HasTable(&TablePlugin{}) { + if err := migrator.CreateTable(&TablePlugin{}); err != nil { + return err + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + migrator := tx.Migrator() + // Drop children first, then parents (adjust if your actual FKs differ) + if err := migrator.DropTable(&TableVirtualKey{}); err != nil { + return err + } + if err := migrator.DropTable(&TableKey{}); err != nil { + return err + } + if err := migrator.DropTable(&TableTeam{}); err != nil { + return err + } + if err := migrator.DropTable(&TableProvider{}); err != nil { + return err + } + if err := migrator.DropTable(&TableCustomer{}); err != nil { + return err + } + if err := migrator.DropTable(&TableBudget{}); err != nil { + return err + } + if err := migrator.DropTable(&TableRateLimit{}); err != nil { + return err + } + if err := migrator.DropTable(&TableModel{}); err != nil { + return err + } + if err := migrator.DropTable(&TableMCPClient{}); err != nil { + return err + } + if err := migrator.DropTable(&TableClientConfig{}); err != nil { + return err + } + if err := migrator.DropTable(&TableEnvKey{}); err != nil { + return err + } + if err := migrator.DropTable(&TableVectorStoreConfig{}); err != nil { + return err + } + if err := migrator.DropTable(&TableLogStoreConfig{}); err != nil { + return err + } + if err := migrator.DropTable(&TableConfig{}); err != nil { + return err + } + if err := migrator.DropTable(&TableModelPricing{}); err != nil { + return err + } + if err := migrator.DropTable(&TablePlugin{}); err != nil { + return err + } + if err := migrator.DropTable(&TableConfigHash{}); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// createMany2ManyJoinTable creates a many-to-many join table for the given tables. +func migrationMany2ManyJoinTable(db *gorm.DB) error { + m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{ + ID: "many2manyjoin", + Migrate: func(tx *gorm.DB) error { + migrator := tx.Migrator() + + // create the many-to-many join table for virtual keys and keys + if !migrator.HasTable("governance_virtual_key_keys") { + createJoinTableSQL := ` + CREATE TABLE IF NOT EXISTS governance_virtual_key_keys ( + table_virtual_key_id VARCHAR(255) NOT NULL, + table_key_id INTEGER NOT NULL, + PRIMARY KEY (table_virtual_key_id, table_key_id), + FOREIGN KEY (table_virtual_key_id) REFERENCES governance_virtual_keys(id) ON DELETE CASCADE, + FOREIGN KEY (table_key_id) REFERENCES config_keys(id) ON DELETE CASCADE + ) + ` + if err := tx.Exec(createJoinTableSQL).Error; err != nil { + return fmt.Errorf("failed to create governance_virtual_key_keys table: %w", err) + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + if err := tx.Exec("DROP TABLE IF EXISTS governance_virtual_key_keys").Error; err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +func migrationAddCustomProviderConfigJSONColumn(db *gorm.DB) error { + m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{ + ID: "addcustomproviderconfigjsoncolumn", + Migrate: func(tx *gorm.DB) error { + migrator := tx.Migrator() + + if !migrator.HasColumn(&TableProvider{}, "custom_provider_config_json") { + if err := migrator.AddColumn(&TableProvider{}, "custom_provider_config_json"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/sqlite.go b/framework/configstore/sqlite.go new file mode 100644 index 000000000..76514bbba --- /dev/null +++ b/framework/configstore/sqlite.go @@ -0,0 +1,1275 @@ +package configstore + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/vectorstore" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" +) + +// processEnvValue processes a value that might be an environment variable reference +func processEnvValue(value string, logger schemas.Logger) (string, error) { + v := strings.TrimSpace(value) + if !strings.HasPrefix(v, "env.") { + return value, nil + } + envKey := strings.TrimSpace(strings.TrimPrefix(v, "env.")) + if envKey == "" { + logger.Warn(fmt.Sprintf("Environment variable name missing in value: %s", value)) + return "", fmt.Errorf("environment variable name missing in %q", value) + } + if envValue, ok := os.LookupEnv(envKey); ok { + return envValue, nil + } + logger.Warn(fmt.Sprintf("Environment variable not found: %s", envKey)) + return "", fmt.Errorf("environment variable %s not found", envKey) +} + +// SQLiteConfig represents the configuration for a SQLite database. +type SQLiteConfig struct { + Path string `json:"path"` +} + +// SQLiteConfigStore represents a configuration store that uses a SQLite database. +type SQLiteConfigStore struct { + db *gorm.DB + logger schemas.Logger +} + +// UpdateClientConfig updates the client configuration in the database. +func (s *SQLiteConfigStore) UpdateClientConfig(config *ClientConfig) error { + dbConfig := TableClientConfig{ + DropExcessRequests: config.DropExcessRequests, + InitialPoolSize: config.InitialPoolSize, + EnableLogging: config.EnableLogging, + EnableGovernance: config.EnableGovernance, + EnforceGovernanceHeader: config.EnforceGovernanceHeader, + AllowDirectKeys: config.AllowDirectKeys, + PrometheusLabels: config.PrometheusLabels, + AllowedOrigins: config.AllowedOrigins, + MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, + } + // Delete existing client config and create new one in a transaction + return s.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableClientConfig{}).Error; err != nil { + return err + } + return tx.Create(&dbConfig).Error + }) +} + +// GetClientConfig retrieves the client configuration from the database. +func (s *SQLiteConfigStore) GetClientConfig() (*ClientConfig, error) { + var dbConfig TableClientConfig + if err := s.db.First(&dbConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &ClientConfig{ + DropExcessRequests: dbConfig.DropExcessRequests, + InitialPoolSize: dbConfig.InitialPoolSize, + PrometheusLabels: dbConfig.PrometheusLabels, + EnableLogging: dbConfig.EnableLogging, + EnableGovernance: dbConfig.EnableGovernance, + EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, + AllowDirectKeys: dbConfig.AllowDirectKeys, + AllowedOrigins: dbConfig.AllowedOrigins, + MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, + }, nil +} + +// UpdateProvidersConfig updates the client configuration in the database. +func (s *SQLiteConfigStore) UpdateProvidersConfig(providers map[schemas.ModelProvider]ProviderConfig) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Delete all existing providers (cascades to keys) + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableProvider{}).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + + for providerName, providerConfig := range providers { + dbProvider := TableProvider{ + Name: string(providerName), + NetworkConfig: providerConfig.NetworkConfig, + ConcurrencyAndBufferSize: providerConfig.ConcurrencyAndBufferSize, + ProxyConfig: providerConfig.ProxyConfig, + SendBackRawResponse: providerConfig.SendBackRawResponse, + CustomProviderConfig: providerConfig.CustomProviderConfig, + } + + // Create provider first + if err := tx.Create(&dbProvider).Error; err != nil { + return err + } + + // Create keys for this provider + dbKeys := make([]TableKey, 0, len(providerConfig.Keys)) + for _, key := range providerConfig.Keys { + dbKey := TableKey{ + Provider: dbProvider.Name, + ProviderID: dbProvider.ID, + KeyID: key.ID, + Value: key.Value, + Models: key.Models, + Weight: key.Weight, + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + } + + // Handle Azure config + if key.AzureKeyConfig != nil { + dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint + dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion + } + + // Handle Vertex config + if key.VertexKeyConfig != nil { + dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID + dbKey.VertexRegion = &key.VertexKeyConfig.Region + dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials + } + + // Handle Bedrock config + if key.BedrockKeyConfig != nil { + dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey + dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey + dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken + dbKey.BedrockRegion = key.BedrockKeyConfig.Region + dbKey.BedrockARN = key.BedrockKeyConfig.ARN + } + + dbKeys = append(dbKeys, dbKey) + } + + // Upsert keys to handle duplicates properly + for _, dbKey := range dbKeys { + // First try to find existing key by KeyID + var existingKey TableKey + result := tx.Where("key_id = ?", dbKey.KeyID).First(&existingKey) + + if result.Error == nil { + // Update existing key with new data + dbKey.ID = existingKey.ID // Keep the same database ID + if err := tx.Save(&dbKey).Error; err != nil { + return err + } + } else if errors.Is(result.Error, gorm.ErrRecordNotFound) { + // Create new key + if err := tx.Create(&dbKey).Error; err != nil { + return err + } + } else { + // Other error occurred + return result.Error + } + } + } + return nil + }) +} + +// UpdateProviderById updates a single provider configuration in the database without deleting/recreating. +func (s *SQLiteConfigStore) UpdateProvider(provider schemas.ModelProvider, config ProviderConfig, envKeys map[string][]EnvKeyInfo) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Find the existing provider + var dbProvider TableProvider + if err := tx.Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + + // Create a deep copy of the config to avoid modifying the original + configCopy, err := deepCopy(config) + if err != nil { + return err + } + // Substitute environment variables back to their original form + substituteEnvVars(&configCopy, provider, envKeys) + + // Update provider fields + dbProvider.NetworkConfig = configCopy.NetworkConfig + dbProvider.ConcurrencyAndBufferSize = configCopy.ConcurrencyAndBufferSize + dbProvider.ProxyConfig = configCopy.ProxyConfig + dbProvider.SendBackRawResponse = configCopy.SendBackRawResponse + dbProvider.CustomProviderConfig = configCopy.CustomProviderConfig + + // Save the updated provider + if err := tx.Save(&dbProvider).Error; err != nil { + return err + } + + // Get existing keys for this provider + var existingKeys []TableKey + if err := tx.Where("provider_id = ?", dbProvider.ID).Find(&existingKeys).Error; err != nil { + return err + } + + // Create a map of existing keys by KeyID for quick lookup + existingKeysMap := make(map[string]TableKey) + for _, key := range existingKeys { + existingKeysMap[key.KeyID] = key + } + + // Process each key in the new config + for _, key := range configCopy.Keys { + dbKey := TableKey{ + Provider: dbProvider.Name, + ProviderID: dbProvider.ID, + KeyID: key.ID, + Value: key.Value, + Models: key.Models, + Weight: key.Weight, + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + } + + // Handle Azure config + if key.AzureKeyConfig != nil { + dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint + dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion + } + + // Handle Vertex config + if key.VertexKeyConfig != nil { + dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID + dbKey.VertexRegion = &key.VertexKeyConfig.Region + dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials + } + + // Handle Bedrock config + if key.BedrockKeyConfig != nil { + dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey + dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey + dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken + dbKey.BedrockRegion = key.BedrockKeyConfig.Region + dbKey.BedrockARN = key.BedrockKeyConfig.ARN + } + + // Check if this key already exists + if existingKey, exists := existingKeysMap[key.ID]; exists { + // Update existing key - preserve the database ID + dbKey.ID = existingKey.ID + if err := tx.Save(&dbKey).Error; err != nil { + return err + } + // Remove from map to track which keys are still in use + delete(existingKeysMap, key.ID) + } else { + // Create new key + if err := tx.Create(&dbKey).Error; err != nil { + return err + } + } + } + + // Delete keys that are no longer in the new config + for _, keyToDelete := range existingKeysMap { + if err := tx.Delete(&keyToDelete).Error; err != nil { + return err + } + } + + return nil + }) +} + +// AddProvider creates a new provider configuration in the database. +func (s *SQLiteConfigStore) AddProvider(provider schemas.ModelProvider, config ProviderConfig, envKeys map[string][]EnvKeyInfo) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Check if provider already exists + var existingProvider TableProvider + if err := tx.Where("name = ?", string(provider)).First(&existingProvider).Error; err == nil { + return fmt.Errorf("provider %s already exists", provider) + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // Create a deep copy of the config to avoid modifying the original + configCopy, err := deepCopy(config) + if err != nil { + return err + } + // Substitute environment variables back to their original form + substituteEnvVars(&configCopy, provider, envKeys) + + // Create new provider + dbProvider := TableProvider{ + Name: string(provider), + NetworkConfig: configCopy.NetworkConfig, + ConcurrencyAndBufferSize: configCopy.ConcurrencyAndBufferSize, + ProxyConfig: configCopy.ProxyConfig, + SendBackRawResponse: configCopy.SendBackRawResponse, + CustomProviderConfig: configCopy.CustomProviderConfig, + } + + // Create the provider + if err := tx.Create(&dbProvider).Error; err != nil { + return err + } + + // Create keys for this provider + for _, key := range configCopy.Keys { + dbKey := TableKey{ + Provider: dbProvider.Name, + ProviderID: dbProvider.ID, + KeyID: key.ID, + Value: key.Value, + Models: key.Models, + Weight: key.Weight, + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + } + + // Handle Azure config + if key.AzureKeyConfig != nil { + dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint + dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion + } + + // Handle Vertex config + if key.VertexKeyConfig != nil { + dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID + dbKey.VertexRegion = &key.VertexKeyConfig.Region + dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials + } + + // Handle Bedrock config + if key.BedrockKeyConfig != nil { + dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey + dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey + dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken + dbKey.BedrockRegion = key.BedrockKeyConfig.Region + dbKey.BedrockARN = key.BedrockKeyConfig.ARN + } + + // Create the key + if err := tx.Create(&dbKey).Error; err != nil { + return err + } + } + + return nil + }) +} + +// DeleteProvider deletes a single provider and all its associated keys from the database. +func (s *SQLiteConfigStore) DeleteProvider(provider schemas.ModelProvider) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Find the existing provider + var dbProvider TableProvider + if err := tx.Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + + // Delete the provider (keys will be deleted due to CASCADE constraint) + if err := tx.Delete(&dbProvider).Error; err != nil { + return err + } + + return nil + }) +} + +// GetProvidersConfig retrieves the provider configuration from the database. +func (s *SQLiteConfigStore) GetProvidersConfig() (map[schemas.ModelProvider]ProviderConfig, error) { + var dbProviders []TableProvider + if err := s.db.Preload("Keys").Find(&dbProviders).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + if len(dbProviders) == 0 { + // No providers in database, auto-detect from environment + return nil, nil + } + processedProviders := make(map[schemas.ModelProvider]ProviderConfig) + for _, dbProvider := range dbProviders { + provider := schemas.ModelProvider(dbProvider.Name) + // Convert database keys to schemas.Key + keys := make([]schemas.Key, len(dbProvider.Keys)) + for i, dbKey := range dbProvider.Keys { + // Process main key value + processedValue, err := processEnvValue(dbKey.Value, s.logger) + if err != nil { + // If env var not found, keep the original value + processedValue = dbKey.Value + } + + // Process Azure config if present + azureConfig := dbKey.AzureKeyConfig + if azureConfig != nil { + azureConfigCopy := *azureConfig + if processedEndpoint, err := processEnvValue(azureConfig.Endpoint, s.logger); err == nil { + azureConfigCopy.Endpoint = processedEndpoint + } + if azureConfig.APIVersion != nil { + if processedAPIVersion, err := processEnvValue(*azureConfig.APIVersion, s.logger); err == nil { + azureConfigCopy.APIVersion = &processedAPIVersion + } + } + azureConfig = &azureConfigCopy + } + + // Process Vertex config if present + vertexConfig := dbKey.VertexKeyConfig + if vertexConfig != nil { + vertexConfigCopy := *vertexConfig + if processedProjectID, err := processEnvValue(vertexConfig.ProjectID, s.logger); err == nil { + vertexConfigCopy.ProjectID = processedProjectID + } + if processedRegion, err := processEnvValue(vertexConfig.Region, s.logger); err == nil { + vertexConfigCopy.Region = processedRegion + } + if processedAuthCredentials, err := processEnvValue(vertexConfig.AuthCredentials, s.logger); err == nil { + vertexConfigCopy.AuthCredentials = processedAuthCredentials + } + vertexConfig = &vertexConfigCopy + } + + // Process Bedrock config if present + bedrockConfig := dbKey.BedrockKeyConfig + if bedrockConfig != nil { + bedrockConfigCopy := *bedrockConfig + if processedAccessKey, err := processEnvValue(bedrockConfig.AccessKey, s.logger); err == nil { + bedrockConfigCopy.AccessKey = processedAccessKey + } + if processedSecretKey, err := processEnvValue(bedrockConfig.SecretKey, s.logger); err == nil { + bedrockConfigCopy.SecretKey = processedSecretKey + } + if bedrockConfig.SessionToken != nil { + if processedSessionToken, err := processEnvValue(*bedrockConfig.SessionToken, s.logger); err == nil { + bedrockConfigCopy.SessionToken = &processedSessionToken + } + } + if bedrockConfig.Region != nil { + if processedRegion, err := processEnvValue(*bedrockConfig.Region, s.logger); err == nil { + bedrockConfigCopy.Region = &processedRegion + } + } + if bedrockConfig.ARN != nil { + if processedARN, err := processEnvValue(*bedrockConfig.ARN, s.logger); err == nil { + bedrockConfigCopy.ARN = &processedARN + } + } + bedrockConfig = &bedrockConfigCopy + } + + keys[i] = schemas.Key{ + ID: dbKey.KeyID, + Value: processedValue, + Models: dbKey.Models, + Weight: dbKey.Weight, + AzureKeyConfig: azureConfig, + VertexKeyConfig: vertexConfig, + BedrockKeyConfig: bedrockConfig, + } + } + providerConfig := ProviderConfig{ + Keys: keys, + NetworkConfig: dbProvider.NetworkConfig, + ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize, + ProxyConfig: dbProvider.ProxyConfig, + SendBackRawResponse: dbProvider.SendBackRawResponse, + CustomProviderConfig: dbProvider.CustomProviderConfig, + } + processedProviders[provider] = providerConfig + } + return processedProviders, nil +} + +// GetMCPConfig retrieves the MCP configuration from the database. +func (s *SQLiteConfigStore) GetMCPConfig() (*schemas.MCPConfig, error) { + var dbMCPClients []TableMCPClient + if err := s.db.Find(&dbMCPClients).Error; err != nil { + return nil, err + } + if len(dbMCPClients) == 0 { + return nil, nil + } + clientConfigs := make([]schemas.MCPClientConfig, len(dbMCPClients)) + for i, dbClient := range dbMCPClients { + // Process connection string for environment variables + var processedConnectionString *string + if dbClient.ConnectionString != nil { + processedValue, err := processEnvValue(*dbClient.ConnectionString, s.logger) + if err != nil { + // If env var not found, keep the original value + processedValue = *dbClient.ConnectionString + } + processedConnectionString = &processedValue + } + + clientConfigs[i] = schemas.MCPClientConfig{ + Name: dbClient.Name, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: processedConnectionString, + StdioConfig: dbClient.StdioConfig, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToSkip: dbClient.ToolsToSkip, + } + } + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, + }, nil +} + +// UpdateMCPConfig updates the MCP configuration in the database. +func (s *SQLiteConfigStore) UpdateMCPConfig(config *schemas.MCPConfig, envKeys map[string][]EnvKeyInfo) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Removing existing MCP clients + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableMCPClient{}).Error; err != nil { + return err + } + + if config == nil { + return nil + } + + // Create a deep copy of the config to avoid modifying the original + configCopy, err := deepCopy(config) + if err != nil { + return err + } + // Substitute environment variables back to their original form + substituteMCPEnvVars(configCopy, envKeys) + + dbClients := make([]TableMCPClient, 0, len(configCopy.ClientConfigs)) + for _, clientConfig := range configCopy.ClientConfigs { + dbClient := TableMCPClient{ + Name: clientConfig.Name, + ConnectionType: string(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToSkip: clientConfig.ToolsToSkip, + } + + dbClients = append(dbClients, dbClient) + } + + if len(dbClients) > 0 { + if err := tx.CreateInBatches(dbClients, 100).Error; err != nil { + return err + } + } + + return nil + }) +} + +// GetVectorStoreConfig retrieves the vector store configuration from the database. +func (s *SQLiteConfigStore) GetVectorStoreConfig() (*vectorstore.Config, error) { + var vectorStoreTableConfig TableVectorStoreConfig + if err := s.db.First(&vectorStoreTableConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // Return default cache configuration + return nil, nil + } + return nil, err + } + return &vectorstore.Config{ + Enabled: vectorStoreTableConfig.Enabled, + Config: vectorStoreTableConfig.Config, + Type: vectorstore.VectorStoreType(vectorStoreTableConfig.Type), + }, nil +} + +// UpdateVectorStoreConfig updates the vector store configuration in the database. +func (s *SQLiteConfigStore) UpdateVectorStoreConfig(config *vectorstore.Config) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Delete existing cache config + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableVectorStoreConfig{}).Error; err != nil { + return err + } + jsonConfig, err := marshalToStringPtr(config.Config) + if err != nil { + return err + } + var record = &TableVectorStoreConfig{ + Type: string(config.Type), + Enabled: config.Enabled, + Config: jsonConfig, + } + // Create new cache config + return tx.Create(record).Error + }) +} + +// GetLogsStoreConfig retrieves the logs store configuration from the database. +func (s *SQLiteConfigStore) GetLogsStoreConfig() (*logstore.Config, error) { + var dbConfig TableLogStoreConfig + if err := s.db.First(&dbConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + if dbConfig.Config == nil || *dbConfig.Config == "" { + return &logstore.Config{Enabled: dbConfig.Enabled}, nil + } + var logStoreConfig logstore.Config + if err := json.Unmarshal([]byte(*dbConfig.Config), &logStoreConfig); err != nil { + return nil, err + } + return &logStoreConfig, nil +} + +// UpdateLogsStoreConfig updates the logs store configuration in the database. +func (s *SQLiteConfigStore) UpdateLogsStoreConfig(config *logstore.Config) error { + return s.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableLogStoreConfig{}).Error; err != nil { + return err + } + jsonConfig, err := marshalToStringPtr(config) + if err != nil { + return err + } + var record = &TableLogStoreConfig{ + Enabled: config.Enabled, + Type: string(config.Type), + Config: jsonConfig, + } + return tx.Create(record).Error + }) +} + +// GetEnvKeys retrieves the environment keys from the database. +func (s *SQLiteConfigStore) GetEnvKeys() (map[string][]EnvKeyInfo, error) { + var dbEnvKeys []TableEnvKey + if err := s.db.Find(&dbEnvKeys).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + envKeys := make(map[string][]EnvKeyInfo) + for _, dbEnvKey := range dbEnvKeys { + envKeys[dbEnvKey.EnvVar] = append(envKeys[dbEnvKey.EnvVar], EnvKeyInfo{ + EnvVar: dbEnvKey.EnvVar, + Provider: schemas.ModelProvider(dbEnvKey.Provider), + KeyType: EnvKeyType(dbEnvKey.KeyType), + ConfigPath: dbEnvKey.ConfigPath, + KeyID: dbEnvKey.KeyID, + }) + } + return envKeys, nil +} + +// UpdateEnvKeys updates the environment keys in the database. +func (s *SQLiteConfigStore) UpdateEnvKeys(keys map[string][]EnvKeyInfo) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Delete existing env keys + if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableEnvKey{}).Error; err != nil { + return err + } + var dbEnvKeys []TableEnvKey + for envVar, infos := range keys { + for _, info := range infos { + dbEnvKey := TableEnvKey{ + EnvVar: envVar, + Provider: string(info.Provider), + KeyType: string(info.KeyType), + ConfigPath: info.ConfigPath, + KeyID: info.KeyID, + } + dbEnvKeys = append(dbEnvKeys, dbEnvKey) + } + } + if len(dbEnvKeys) > 0 { + if err := tx.CreateInBatches(dbEnvKeys, 100).Error; err != nil { + return err + } + } + return nil + }) +} + +// GetConfig retrieves a specific config from the database. +func (s *SQLiteConfigStore) GetConfig(key string) (*TableConfig, error) { + var config TableConfig + if err := s.db.First(&config, "key = ?", key).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &config, nil +} + +// UpdateConfig updates a specific config in the database. +func (s *SQLiteConfigStore) UpdateConfig(config *TableConfig, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Save(config).Error +} + +// GetModelPrices retrieves all model pricing records from the database. +func (s *SQLiteConfigStore) GetModelPrices() ([]TableModelPricing, error) { + var modelPrices []TableModelPricing + if err := s.db.Find(&modelPrices).Error; err != nil { + return nil, err + } + return modelPrices, nil +} + +// CreateModelPrices creates a new model pricing record in the database. +func (s *SQLiteConfigStore) CreateModelPrices(pricing *TableModelPricing, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(pricing).Error +} + +// DeleteModelPrices deletes all model pricing records from the database. +func (s *SQLiteConfigStore) DeleteModelPrices(tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableModelPricing{}).Error +} + +// PLUGINS METHODS + +func (s *SQLiteConfigStore) GetPlugins() ([]TablePlugin, error) { + var plugins []TablePlugin + if err := s.db.Find(&plugins).Error; err != nil { + return nil, err + } + return plugins, nil +} + +func (s *SQLiteConfigStore) GetPlugin(name string) (*TablePlugin, error) { + var plugin TablePlugin + if err := s.db.First(&plugin, "name = ?", name).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &plugin, nil +} + +func (s *SQLiteConfigStore) CreatePlugin(plugin *TablePlugin, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(plugin).Error +} + +func (s *SQLiteConfigStore) UpdatePlugin(plugin *TablePlugin, tx ...*gorm.DB) error { + var txDB *gorm.DB + var localTx bool + + if len(tx) > 0 { + txDB = tx[0] + localTx = false + } else { + txDB = s.db.Begin() + localTx = true + } + + if err := txDB.Delete(&TablePlugin{}, "name = ?", plugin.Name).Error; err != nil { + if localTx { + txDB.Rollback() + } + return err + } + + if err := txDB.Create(plugin).Error; err != nil { + if localTx { + txDB.Rollback() + } + return err + } + + if localTx { + return txDB.Commit().Error + } + + return nil +} + +func (s *SQLiteConfigStore) DeletePlugin(name string, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Delete(&TablePlugin{}, "name = ?", name).Error +} + +// GOVERNANCE METHODS + +// GetVirtualKeys retrieves all virtual keys from the database. +func (s *SQLiteConfigStore) GetVirtualKeys() ([]TableVirtualKey, error) { + var virtualKeys []TableVirtualKey + + // Preload all relationships for complete information + if err := s.db.Preload("Team"). + Preload("Customer"). + Preload("Budget"). + Preload("RateLimit"). + Preload("Keys", func(db *gorm.DB) *gorm.DB { + return db.Select("id, key_id, models_json") + }).Find(&virtualKeys).Error; err != nil { + return nil, err + } + + return virtualKeys, nil +} + +// GetVirtualKey retrieves a virtual key from the database. +func (s *SQLiteConfigStore) GetVirtualKey(id string) (*TableVirtualKey, error) { + var virtualKey TableVirtualKey + if err := s.db.Preload("Team"). + Preload("Customer"). + Preload("Budget"). + Preload("RateLimit"). + Preload("Keys", func(db *gorm.DB) *gorm.DB { + return db.Select("id, key_id, models_json") + }).First(&virtualKey, "id = ?", id).Error; err != nil { + return nil, err + } + return &virtualKey, nil +} + +func (s *SQLiteConfigStore) CreateVirtualKey(virtualKey *TableVirtualKey, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + + // Create virtual key first + if err := txDB.Create(virtualKey).Error; err != nil { + return err + } + + // Create key associations after the virtual key has an ID + if len(virtualKey.Keys) > 0 { + if err := txDB.Model(virtualKey).Association("Keys").Append(virtualKey.Keys); err != nil { + return err + } + } + + return nil +} + +func (s *SQLiteConfigStore) UpdateVirtualKey(virtualKey *TableVirtualKey, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + + // Store the keys before Save() clears them + keysToAssociate := virtualKey.Keys + + // Update virtual key first (this will clear the Keys field) + if err := txDB.Save(virtualKey).Error; err != nil { + return err + } + + // Clear existing key associations + if err := txDB.Model(virtualKey).Association("Keys").Clear(); err != nil { + return err + } + + // Create new key associations using the stored keys + if len(keysToAssociate) > 0 { + if err := txDB.Model(virtualKey).Association("Keys").Append(keysToAssociate); err != nil { + return err + } + } + + return nil +} + +// GetKeysByIDs retrieves multiple keys by their IDs +func (s *SQLiteConfigStore) GetKeysByIDs(ids []string) ([]TableKey, error) { + if len(ids) == 0 { + return []TableKey{}, nil + } + + var keys []TableKey + if err := s.db.Where("key_id IN ?", ids).Find(&keys).Error; err != nil { + return nil, err + } + return keys, nil +} + +// DeleteVirtualKey deletes a virtual key from the database. +func (s *SQLiteConfigStore) DeleteVirtualKey(id string) error { + return s.db.Delete(&TableVirtualKey{}, "id = ?", id).Error +} + +// GetTeams retrieves all teams from the database. +func (s *SQLiteConfigStore) GetTeams(customerID string) ([]TableTeam, error) { + // Preload relationships for complete information + query := s.db.Preload("Customer").Preload("Budget") + + // Optional filtering by customer + if customerID != "" { + query = query.Where("customer_id = ?", customerID) + } + + var teams []TableTeam + if err := query.Find(&teams).Error; err != nil { + return nil, err + } + return teams, nil +} + +// GetTeam retrieves a specific team from the database. +func (s *SQLiteConfigStore) GetTeam(id string) (*TableTeam, error) { + var team TableTeam + if err := s.db.Preload("Customer").Preload("Budget").First(&team, "id = ?", id).Error; err != nil { + return nil, err + } + return &team, nil +} + +// CreateTeam creates a new team in the database. +func (s *SQLiteConfigStore) CreateTeam(team *TableTeam, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(team).Error +} + +// UpdateTeam updates an existing team in the database. +func (s *SQLiteConfigStore) UpdateTeam(team *TableTeam, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Save(team).Error +} + +// DeleteTeam deletes a team from the database. +func (s *SQLiteConfigStore) DeleteTeam(id string) error { + return s.db.Delete(&TableTeam{}, "id = ?", id).Error +} + +// GetCustomers retrieves all customers from the database. +func (s *SQLiteConfigStore) GetCustomers() ([]TableCustomer, error) { + var customers []TableCustomer + if err := s.db.Preload("Teams").Preload("Budget").Find(&customers).Error; err != nil { + return nil, err + } + return customers, nil +} + +// GetCustomer retrieves a specific customer from the database. +func (s *SQLiteConfigStore) GetCustomer(id string) (*TableCustomer, error) { + var customer TableCustomer + if err := s.db.Preload("Teams").Preload("Budget").First(&customer, "id = ?", id).Error; err != nil { + return nil, err + } + return &customer, nil +} + +// CreateCustomer creates a new customer in the database. +func (s *SQLiteConfigStore) CreateCustomer(customer *TableCustomer, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(customer).Error +} + +// UpdateCustomer updates an existing customer in the database. +func (s *SQLiteConfigStore) UpdateCustomer(customer *TableCustomer, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Save(customer).Error +} + +// DeleteCustomer deletes a customer from the database. +func (s *SQLiteConfigStore) DeleteCustomer(id string) error { + return s.db.Delete(&TableCustomer{}, "id = ?", id).Error +} + +// GetRateLimit retrieves a specific rate limit from the database. +func (s *SQLiteConfigStore) GetRateLimit(id string) (*TableRateLimit, error) { + var rateLimit TableRateLimit + if err := s.db.First(&rateLimit, "id = ?", id).Error; err != nil { + return nil, err + } + return &rateLimit, nil +} + +// CreateRateLimit creates a new rate limit in the database. +func (s *SQLiteConfigStore) CreateRateLimit(rateLimit *TableRateLimit, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(rateLimit).Error +} + +// UpdateRateLimit updates a rate limit in the database. +func (s *SQLiteConfigStore) UpdateRateLimit(rateLimit *TableRateLimit, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Save(rateLimit).Error +} + +// UpdateRateLimits updates multiple rate limits in the database. +func (s *SQLiteConfigStore) UpdateRateLimits(rateLimits []*TableRateLimit, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + for _, rl := range rateLimits { + if err := txDB.Save(rl).Error; err != nil { + return err + } + } + return nil +} + +// GetBudgets retrieves all budgets from the database. +func (s *SQLiteConfigStore) GetBudgets() ([]TableBudget, error) { + var budgets []TableBudget + if err := s.db.Find(&budgets).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return budgets, nil +} + +// GetBudget retrieves a specific budget from the database. +func (s *SQLiteConfigStore) GetBudget(id string, tx ...*gorm.DB) (*TableBudget, error) { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + var budget TableBudget + if err := txDB.First(&budget, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &budget, nil +} + +// CreateBudget creates a new budget in the database. +func (s *SQLiteConfigStore) CreateBudget(budget *TableBudget, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Create(budget).Error +} + +// UpdateBudgets updates multiple budgets in the database. +func (s *SQLiteConfigStore) UpdateBudgets(budgets []*TableBudget, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + s.logger.Debug("updating budgets: %+v", budgets) + for _, b := range budgets { + if err := txDB.Save(b).Error; err != nil { + return err + } + } + return nil +} + +// UpdateBudget updates a budget in the database. +func (s *SQLiteConfigStore) UpdateBudget(budget *TableBudget, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + return txDB.Save(budget).Error +} + +// GetGovernanceConfig retrieves the governance configuration from the database. +func (s *SQLiteConfigStore) GetGovernanceConfig() (*GovernanceConfig, error) { + var virtualKeys []TableVirtualKey + var teams []TableTeam + var customers []TableCustomer + var budgets []TableBudget + var rateLimits []TableRateLimit + + if err := s.db.Find(&virtualKeys).Error; err != nil { + return nil, err + } + if err := s.db.Find(&teams).Error; err != nil { + return nil, err + } + if err := s.db.Find(&customers).Error; err != nil { + return nil, err + } + if err := s.db.Find(&budgets).Error; err != nil { + return nil, err + } + if err := s.db.Find(&rateLimits).Error; err != nil { + return nil, err + } + + if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 { + return nil, nil + } + + return &GovernanceConfig{ + VirtualKeys: virtualKeys, + Teams: teams, + Customers: customers, + Budgets: budgets, + RateLimits: rateLimits, + }, nil +} + +// ExecuteTransaction executes a transaction. +func (s *SQLiteConfigStore) ExecuteTransaction(fn func(tx *gorm.DB) error) error { + return s.db.Transaction(fn) +} + +func (s *SQLiteConfigStore) doesTableExist(tableName string) bool { + var count int64 + if err := s.db.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count).Error; err != nil { + return false + } + return count > 0 +} + +// removeNullKeys removes null keys from the database. +func (s *SQLiteConfigStore) removeNullKeys() error { + return s.db.Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error +} + +// removeDuplicateKeysAndNullKeys removes duplicate keys based on key_id and value combination +// Keeps the record with the smallest ID (oldest record) and deletes duplicates +func (s *SQLiteConfigStore) removeDuplicateKeysAndNullKeys() error { + s.logger.Debug("removing duplicate keys and null keys from the database") + // Check if the config_keys table exists first + if !s.doesTableExist("config_keys") { + return nil + } + s.logger.Debug("removing null keys from the database") + // First, remove null keys + if err := s.removeNullKeys(); err != nil { + return fmt.Errorf("failed to remove null keys: %w", err) + } + s.logger.Debug("deleting duplicate keys from the database") + // Find and delete duplicate keys, keeping only the one with the smallest ID + // This query deletes all records except the one with the minimum ID for each (key_id, value) pair + result := s.db.Exec(` + DELETE FROM config_keys + WHERE id NOT IN ( + SELECT MIN(id) + FROM config_keys + GROUP BY key_id, value + ) + `) + + if result.Error != nil { + return fmt.Errorf("failed to remove duplicate keys: %w", result.Error) + } + s.logger.Debug("migration complete") + return nil +} + +// newSqliteConfigStore creates a new SQLite config store. +func newSqliteConfigStore(config *SQLiteConfig, logger schemas.Logger) (ConfigStore, error) { + if _, err := os.Stat(config.Path); os.IsNotExist(err) { + // Create DB file + f, err := os.Create(config.Path) + if err != nil { + return nil, err + } + _ = f.Close() + } + dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path) + logger.Debug("opening DB with dsn: %s", dsn) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{ + Logger: gormLogger.Default.LogMode(gormLogger.Silent), + }) + + if err != nil { + return nil, err + } + logger.Debug("db opened for configstore") + s := &SQLiteConfigStore{db: db, logger: logger} + logger.Debug("running migration to remove duplicate keys") + // Run migration to remove duplicate keys before AutoMigrate + if err := s.removeDuplicateKeysAndNullKeys(); err != nil { + return nil, fmt.Errorf("failed to remove duplicate keys: %w", err) + } + // Auto migrate to all new tables + if err := triggerMigrations(db); err != nil { + return nil, err + } + return s, nil +} diff --git a/framework/configstore/store.go b/framework/configstore/store.go new file mode 100644 index 000000000..7e4f9d6d4 --- /dev/null +++ b/framework/configstore/store.go @@ -0,0 +1,119 @@ +// Package configstore provides a persistent configuration store for Bifrost. +package configstore + +import ( + "fmt" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/vectorstore" + "gorm.io/gorm" +) + +// ConfigStore is the interface for the config store. +type ConfigStore interface { + + // Client config CRUD + UpdateClientConfig(config *ClientConfig) error + GetClientConfig() (*ClientConfig, error) + + // Provider config CRUD + UpdateProvidersConfig(providers map[schemas.ModelProvider]ProviderConfig) error + AddProvider(provider schemas.ModelProvider, config ProviderConfig, envKeys map[string][]EnvKeyInfo) error + UpdateProvider(provider schemas.ModelProvider, config ProviderConfig, envKeys map[string][]EnvKeyInfo) error + DeleteProvider(provider schemas.ModelProvider) error + GetProvidersConfig() (map[schemas.ModelProvider]ProviderConfig, error) + + // MCP config CRUD + UpdateMCPConfig(config *schemas.MCPConfig, envKeys map[string][]EnvKeyInfo) error + GetMCPConfig() (*schemas.MCPConfig, error) + + // Vector store config CRUD + UpdateVectorStoreConfig(config *vectorstore.Config) error + GetVectorStoreConfig() (*vectorstore.Config, error) + + // Logs store config CRUD + UpdateLogsStoreConfig(config *logstore.Config) error + GetLogsStoreConfig() (*logstore.Config, error) + + // ENV keys CRUD + UpdateEnvKeys(keys map[string][]EnvKeyInfo) error + GetEnvKeys() (map[string][]EnvKeyInfo, error) + + // Config CRUD + GetConfig(key string) (*TableConfig, error) + UpdateConfig(config *TableConfig, tx ...*gorm.DB) error + + // Plugins CRUD + GetPlugins() ([]TablePlugin, error) + GetPlugin(name string) (*TablePlugin, error) + CreatePlugin(plugin *TablePlugin, tx ...*gorm.DB) error + UpdatePlugin(plugin *TablePlugin, tx ...*gorm.DB) error + DeletePlugin(name string, tx ...*gorm.DB) error + + // Governance config CRUD + GetVirtualKeys() ([]TableVirtualKey, error) + GetVirtualKey(id string) (*TableVirtualKey, error) + CreateVirtualKey(virtualKey *TableVirtualKey, tx ...*gorm.DB) error + UpdateVirtualKey(virtualKey *TableVirtualKey, tx ...*gorm.DB) error + DeleteVirtualKey(id string) error + + // Team CRUD + GetTeams(customerID string) ([]TableTeam, error) + GetTeam(id string) (*TableTeam, error) + CreateTeam(team *TableTeam, tx ...*gorm.DB) error + UpdateTeam(team *TableTeam, tx ...*gorm.DB) error + DeleteTeam(id string) error + + // Customer CRUD + GetCustomers() ([]TableCustomer, error) + GetCustomer(id string) (*TableCustomer, error) + CreateCustomer(customer *TableCustomer, tx ...*gorm.DB) error + UpdateCustomer(customer *TableCustomer, tx ...*gorm.DB) error + DeleteCustomer(id string) error + + // Rate limit CRUD + GetRateLimit(id string) (*TableRateLimit, error) + CreateRateLimit(rateLimit *TableRateLimit, tx ...*gorm.DB) error + UpdateRateLimit(rateLimit *TableRateLimit, tx ...*gorm.DB) error + UpdateRateLimits(rateLimits []*TableRateLimit, tx ...*gorm.DB) error + + // Budget CRUD + GetBudgets() ([]TableBudget, error) + GetBudget(id string, tx ...*gorm.DB) (*TableBudget, error) + CreateBudget(budget *TableBudget, tx ...*gorm.DB) error + UpdateBudget(budget *TableBudget, tx ...*gorm.DB) error + UpdateBudgets(budgets []*TableBudget, tx ...*gorm.DB) error + + GetGovernanceConfig() (*GovernanceConfig, error) + + // Model pricing CRUD + GetModelPrices() ([]TableModelPricing, error) + CreateModelPrices(pricing *TableModelPricing, tx ...*gorm.DB) error + DeleteModelPrices(tx ...*gorm.DB) error + + // Key management + GetKeysByIDs(ids []string) ([]TableKey, error) + + // Generic transaction manager + ExecuteTransaction(fn func(tx *gorm.DB) error) error +} + +// NewConfigStore creates a new config store based on the configuration +func NewConfigStore(config *Config, logger schemas.Logger) (ConfigStore, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + if !config.Enabled { + return nil, nil + } + switch config.Type { + case ConfigStoreTypeSQLite: + if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok { + return newSqliteConfigStore(sqliteConfig, logger) + } + return nil, fmt.Errorf("invalid sqlite config: %T", config.Config) + } + return nil, fmt.Errorf("unsupported config store type: %s", config.Type) +} diff --git a/framework/configstore/tables.go b/framework/configstore/tables.go new file mode 100644 index 000000000..06d5b99fb --- /dev/null +++ b/framework/configstore/tables.go @@ -0,0 +1,830 @@ +package configstore + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "gorm.io/gorm" +) + +// TRANSPORT OPERATION TABLES + +type TableConfigHash struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Hash string `gorm:"type:varchar(255);uniqueIndex;not null" json:"hash"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableProvider represents a provider configuration in the database +type TableProvider struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(50);uniqueIndex;not null" json:"name"` // ModelProvider as string + NetworkConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.NetworkConfig + ConcurrencyBufferJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ConcurrencyAndBufferSize + ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig + CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig + SendBackRawResponse bool `json:"send_back_raw_response"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Relationships + Keys []TableKey `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"keys"` + + // Virtual fields for runtime use (not stored in DB) + NetworkConfig *schemas.NetworkConfig `gorm:"-" json:"network_config,omitempty"` + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `gorm:"-" json:"concurrency_and_buffer_size,omitempty"` + ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"` + + // Custom provider fields + CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"` + + // Foreign keys + Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"` +} + +// TableModel represents a model configuration in the database +type TableModel struct { + ID string `gorm:"primaryKey" json:"id"` + ProviderID uint `gorm:"index;not null;uniqueIndex:idx_provider_name" json:"provider_id"` + Name string `gorm:"uniqueIndex:idx_provider_name" json:"name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TableKey represents an API key configuration in the database +type TableKey struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + ProviderID uint `gorm:"index;not null" json:"provider_id"` + Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string + KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key + Value string `gorm:"type:text;not null" json:"value"` + ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + Weight float64 `gorm:"default:1.0" json:"weight"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Azure config fields (embedded instead of separate table for simplicity) + AzureEndpoint *string `gorm:"type:text" json:"azure_endpoint,omitempty"` + AzureAPIVersion *string `gorm:"type:varchar(50)" json:"azure_api_version,omitempty"` + AzureDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + + // Vertex config fields (embedded) + VertexProjectID *string `gorm:"type:varchar(255)" json:"vertex_project_id,omitempty"` + VertexRegion *string `gorm:"type:varchar(100)" json:"vertex_region,omitempty"` + VertexAuthCredentials *string `gorm:"type:text" json:"vertex_auth_credentials,omitempty"` + + // Bedrock config fields (embedded) + BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"` + BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"` + BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"` + BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"` + BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"` + BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + + // Virtual fields for runtime use (not stored in DB) + Models []string `gorm:"-" json:"models"` + AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` + VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"` + BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` +} + +// TableMCPClient represents an MCP client configuration in the database +type TableMCPClient struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToSkipJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` + ToolsToSkip []string `gorm:"-" json:"tools_to_skip"` +} + +// TableClientConfig represents global client configuration in the database +type TableClientConfig struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + DropExcessRequests bool `gorm:"default:false" json:"drop_excess_requests"` + PrometheusLabelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"` + EnableLogging bool `gorm:"" json:"enable_logging"` + EnableGovernance bool `gorm:"" json:"enable_governance"` + EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"` + AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"` + MaxRequestBodySizeMB int `gorm:"" json:"max_request_body_size_mb"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` + AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` +} + +// TableEnvKey represents environment variable tracking in the database +type TableEnvKey struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + EnvVar string `gorm:"type:varchar(255);index;not null" json:"env_var"` + Provider string `gorm:"type:varchar(50);index" json:"provider"` // Empty for MCP/client configs + KeyType string `gorm:"type:varchar(50);not null" json:"key_type"` // "api_key", "azure_config", "vertex_config", "bedrock_config", "connection_string" + ConfigPath string `gorm:"type:varchar(500);not null" json:"config_path"` // Descriptive path of where this env var is used + KeyID string `gorm:"type:varchar(255);index" json:"key_id"` // Key UUID (empty for non-key configs) + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` +} + +// TableVectorStoreConfig represents Cache plugin configuration in the database +type TableVectorStoreConfig struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Enabled bool `json:"enabled"` // Enable vector store + Type string `gorm:"type:varchar(50);not null" json:"type"` // "weaviate, elasticsearch, pinecone, etc." + TTLSeconds int `gorm:"default:300" json:"ttl_seconds"` // TTL in seconds (default: 5 minutes) + CacheByModel bool `gorm:"" json:"cache_by_model"` // Include model in cache key + CacheByProvider bool `gorm:"" json:"cache_by_provider"` // Include provider in cache key + Config *string `gorm:"type:text" json:"config"` // JSON serialized schemas.RedisVectorStoreConfig + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableLogStoreConfig represents the configuration for the log store in the database +type TableLogStoreConfig struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Enabled bool `json:"enabled"` + Type string `gorm:"type:varchar(50);not null" json:"type"` // "sqlite" + Config *string `gorm:"type:text" json:"config"` // JSON serialized logstore.Config + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TablePlugin represents a plugin configuration in the database + +type TablePlugin struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + Enabled bool `json:"enabled"` + ConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized plugin.Config + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + Config any `gorm:"-" json:"config,omitempty"` +} + +// TableName sets the table name for each model +func (TableConfigHash) TableName() string { return "config_hashes" } +func (TableProvider) TableName() string { return "config_providers" } +func (TableKey) TableName() string { return "config_keys" } +func (TableModel) TableName() string { return "config_models" } +func (TableMCPClient) TableName() string { return "config_mcp_clients" } +func (TableClientConfig) TableName() string { return "config_client" } +func (TableEnvKey) TableName() string { return "config_env_keys" } +func (TableVectorStoreConfig) TableName() string { return "config_vector_store" } +func (TableLogStoreConfig) TableName() string { return "config_log_store" } +func (TablePlugin) TableName() string { return "config_plugins" } + +// GORM Hooks for JSON serialization/deserialization + +// BeforeSave hooks for serialization +func (p *TableProvider) BeforeSave(tx *gorm.DB) error { + if p.NetworkConfig != nil { + data, err := json.Marshal(p.NetworkConfig) + if err != nil { + return err + } + p.NetworkConfigJSON = string(data) + } + + if p.ConcurrencyAndBufferSize != nil { + data, err := json.Marshal(p.ConcurrencyAndBufferSize) + if err != nil { + return err + } + p.ConcurrencyBufferJSON = string(data) + } + + if p.ProxyConfig != nil { + data, err := json.Marshal(p.ProxyConfig) + if err != nil { + return err + } + p.ProxyConfigJSON = string(data) + } + + if p.CustomProviderConfig != nil && p.CustomProviderConfig.BaseProviderType == "" { + return fmt.Errorf("base_provider_type is required when custom_provider_config is set") + } + + if p.CustomProviderConfig != nil { + data, err := json.Marshal(p.CustomProviderConfig) + if err != nil { + return err + } + p.CustomProviderConfigJSON = string(data) + } + + return nil +} + +func (k *TableKey) BeforeSave(tx *gorm.DB) error { + + if k.Models != nil { + data, err := json.Marshal(k.Models) + if err != nil { + return err + } + k.ModelsJSON = string(data) + } else { + k.ModelsJSON = "[]" + } + + if k.AzureKeyConfig != nil { + if k.AzureKeyConfig.Endpoint != "" { + k.AzureEndpoint = &k.AzureKeyConfig.Endpoint + } + k.AzureAPIVersion = k.AzureKeyConfig.APIVersion + if k.AzureKeyConfig.Deployments != nil { + data, err := json.Marshal(k.AzureKeyConfig.Deployments) + if err != nil { + return err + } + s := string(data) + k.AzureDeploymentsJSON = &s + } + } else { + k.AzureEndpoint = nil + k.AzureAPIVersion = nil + k.AzureDeploymentsJSON = nil + } + + if k.VertexKeyConfig != nil { + if k.VertexKeyConfig.ProjectID != "" { + k.VertexProjectID = &k.VertexKeyConfig.ProjectID + } + if k.VertexKeyConfig.Region != "" { + k.VertexRegion = &k.VertexKeyConfig.Region + } + if k.VertexKeyConfig.AuthCredentials != "" { + k.VertexAuthCredentials = &k.VertexKeyConfig.AuthCredentials + } + } else { + k.VertexProjectID = nil + k.VertexRegion = nil + k.VertexAuthCredentials = nil + } + + if k.BedrockKeyConfig != nil { + if k.BedrockKeyConfig.AccessKey != "" { + k.BedrockAccessKey = &k.BedrockKeyConfig.AccessKey + } + if k.BedrockKeyConfig.SecretKey != "" { + k.BedrockSecretKey = &k.BedrockKeyConfig.SecretKey + } + k.BedrockSessionToken = k.BedrockKeyConfig.SessionToken + k.BedrockRegion = k.BedrockKeyConfig.Region + k.BedrockARN = k.BedrockKeyConfig.ARN + if k.BedrockKeyConfig.Deployments != nil { + data, err := json.Marshal(k.BedrockKeyConfig.Deployments) + if err != nil { + return err + } + s := string(data) + k.BedrockDeploymentsJSON = &s + } + } else { + k.BedrockAccessKey = nil + k.BedrockSecretKey = nil + k.BedrockSessionToken = nil + k.BedrockRegion = nil + k.BedrockARN = nil + k.BedrockDeploymentsJSON = nil + } + return nil +} + +func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { + if c.StdioConfig != nil { + data, err := json.Marshal(c.StdioConfig) + if err != nil { + return err + } + config := string(data) + c.StdioConfigJSON = &config + } else { + c.StdioConfigJSON = nil + } + + if c.ToolsToExecute != nil { + data, err := json.Marshal(c.ToolsToExecute) + if err != nil { + return err + } + c.ToolsToExecuteJSON = string(data) + } else { + c.ToolsToExecuteJSON = "[]" + } + + if c.ToolsToSkip != nil { + data, err := json.Marshal(c.ToolsToSkip) + if err != nil { + return err + } + c.ToolsToSkipJSON = string(data) + } else { + c.ToolsToSkipJSON = "[]" + } + + return nil +} + +func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error { + if cc.PrometheusLabels != nil { + data, err := json.Marshal(cc.PrometheusLabels) + if err != nil { + return err + } + cc.PrometheusLabelsJSON = string(data) + } else { + cc.PrometheusLabelsJSON = "[]" + } + + if cc.AllowedOrigins != nil { + data, err := json.Marshal(cc.AllowedOrigins) + if err != nil { + return err + } + cc.AllowedOriginsJSON = string(data) + } else { + cc.AllowedOriginsJSON = "[]" + } + + return nil +} + +func (p *TablePlugin) BeforeSave(tx *gorm.DB) error { + if p.Config != nil { + data, err := json.Marshal(p.Config) + if err != nil { + return err + } + p.ConfigJSON = string(data) + } else { + p.ConfigJSON = "{}" + } + + return nil +} + +// AfterFind hooks for deserialization +func (p *TableProvider) AfterFind(tx *gorm.DB) error { + if p.NetworkConfigJSON != "" { + var config schemas.NetworkConfig + if err := json.Unmarshal([]byte(p.NetworkConfigJSON), &config); err != nil { + return err + } + p.NetworkConfig = &config + } + + if p.ConcurrencyBufferJSON != "" { + var config schemas.ConcurrencyAndBufferSize + if err := json.Unmarshal([]byte(p.ConcurrencyBufferJSON), &config); err != nil { + return err + } + p.ConcurrencyAndBufferSize = &config + } + + if p.ProxyConfigJSON != "" { + var proxyConfig schemas.ProxyConfig + if err := json.Unmarshal([]byte(p.ProxyConfigJSON), &proxyConfig); err != nil { + return err + } + p.ProxyConfig = &proxyConfig + } + + if p.CustomProviderConfigJSON != "" { + var customConfig schemas.CustomProviderConfig + if err := json.Unmarshal([]byte(p.CustomProviderConfigJSON), &customConfig); err != nil { + return err + } + p.CustomProviderConfig = &customConfig + } + + return nil +} + +func (k *TableKey) AfterFind(tx *gorm.DB) error { + if k.ModelsJSON != "" { + if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil { + return err + } + } + + // Reconstruct Azure config if fields are present + if k.AzureEndpoint != nil { + azureConfig := &schemas.AzureKeyConfig{ + Endpoint: *k.AzureEndpoint, + APIVersion: k.AzureAPIVersion, + } + + if k.AzureDeploymentsJSON != nil { + var deployments map[string]string + if err := json.Unmarshal([]byte(*k.AzureDeploymentsJSON), &deployments); err != nil { + return err + } + azureConfig.Deployments = deployments + } + + k.AzureKeyConfig = azureConfig + } + + // Reconstruct Vertex config if fields are present + if k.VertexProjectID != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil { + config := &schemas.VertexKeyConfig{} + + if k.VertexProjectID != nil { + config.ProjectID = *k.VertexProjectID + } + + if k.VertexRegion != nil { + config.Region = *k.VertexRegion + } + if k.VertexAuthCredentials != nil { + config.AuthCredentials = *k.VertexAuthCredentials + } + + k.VertexKeyConfig = config + } + + // Reconstruct Bedrock config if fields are present + if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || (k.BedrockDeploymentsJSON != nil && *k.BedrockDeploymentsJSON != "") { + bedrockConfig := &schemas.BedrockKeyConfig{} + + if k.BedrockAccessKey != nil { + bedrockConfig.AccessKey = *k.BedrockAccessKey + } + + bedrockConfig.SessionToken = k.BedrockSessionToken + bedrockConfig.Region = k.BedrockRegion + bedrockConfig.ARN = k.BedrockARN + + if k.BedrockSecretKey != nil { + bedrockConfig.SecretKey = *k.BedrockSecretKey + } + + if k.BedrockDeploymentsJSON != nil { + var deployments map[string]string + if err := json.Unmarshal([]byte(*k.BedrockDeploymentsJSON), &deployments); err != nil { + return err + } + bedrockConfig.Deployments = deployments + } + + k.BedrockKeyConfig = bedrockConfig + } + + return nil +} + +func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { + if c.StdioConfigJSON != nil { + var config schemas.MCPStdioConfig + if err := json.Unmarshal([]byte(*c.StdioConfigJSON), &config); err != nil { + return err + } + c.StdioConfig = &config + } + + if c.ToolsToExecuteJSON != "" { + if err := json.Unmarshal([]byte(c.ToolsToExecuteJSON), &c.ToolsToExecute); err != nil { + return err + } + } + + if c.ToolsToSkipJSON != "" { + if err := json.Unmarshal([]byte(c.ToolsToSkipJSON), &c.ToolsToSkip); err != nil { + return err + } + } + + return nil +} + +func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error { + if cc.PrometheusLabelsJSON != "" { + if err := json.Unmarshal([]byte(cc.PrometheusLabelsJSON), &cc.PrometheusLabels); err != nil { + return err + } + } + + if cc.AllowedOriginsJSON != "" { + if err := json.Unmarshal([]byte(cc.AllowedOriginsJSON), &cc.AllowedOrigins); err != nil { + return err + } + } + + return nil +} + +func (p *TablePlugin) AfterFind(tx *gorm.DB) error { + if p.ConfigJSON != "" { + if err := json.Unmarshal([]byte(p.ConfigJSON), &p.Config); err != nil { + return err + } + } else { + p.Config = nil + } + + return nil +} + +// TableConfig represents generic configuration key-value pairs +type TableConfig struct { + Key string `gorm:"primaryKey;type:varchar(255)" json:"key"` + Value string `gorm:"type:text" json:"value"` +} + +// GOVERNANCE TABLES + +// TableBudget defines spending limits with configurable reset periods +type TableBudget struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + MaxLimit float64 `gorm:"not null" json:"max_limit"` // Maximum budget in dollars + ResetDuration string `gorm:"type:varchar(50);not null" json:"reset_duration"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y" + LastReset time.Time `gorm:"index" json:"last_reset"` // Last time budget was reset + CurrentUsage float64 `gorm:"default:0" json:"current_usage"` // Current usage in dollars + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableRateLimit defines rate limiting rules for virtual keys using flexible max+reset approach +type TableRateLimit struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + + // Token limits with flexible duration + TokenMaxLimit *int64 `gorm:"default:null" json:"token_max_limit,omitempty"` // Maximum tokens allowed + TokenResetDuration *string `gorm:"type:varchar(50)" json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y" + TokenCurrentUsage int64 `gorm:"default:0" json:"token_current_usage"` // Current token usage + TokenLastReset time.Time `gorm:"index" json:"token_last_reset"` // Last time token counter was reset + + // Request limits with flexible duration + RequestMaxLimit *int64 `gorm:"default:null" json:"request_max_limit,omitempty"` // Maximum requests allowed + RequestResetDuration *string `gorm:"type:varchar(50)" json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y" + RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"` // Current request usage + RequestLastReset time.Time `gorm:"index" json:"request_last_reset"` // Last time request counter was reset + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableCustomer represents a customer entity with budget +type TableCustomer struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"type:varchar(255);not null" json:"name"` + BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` + + // Relationships + Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"` + Teams []TableTeam `gorm:"foreignKey:CustomerID" json:"teams"` + VirtualKeys []TableVirtualKey `gorm:"foreignKey:CustomerID" json:"virtual_keys"` + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableTeam represents a team entity with budget and customer association +type TableTeam struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"type:varchar(255);not null" json:"name"` + CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` // A team can belong to a customer + BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` + + // Relationships + Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"` + Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"` + VirtualKeys []TableVirtualKey `gorm:"foreignKey:TeamID" json:"virtual_keys"` + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association +type TableVirtualKey struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"` + Description string `gorm:"type:text" json:"description,omitempty"` + Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value + IsActive bool `gorm:"default:true" json:"is_active"` + AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed + AllowedProviders []string `gorm:"type:text;serializer:json" json:"allowed_providers"` // Empty means all providers allowed + + // Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both) + TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` + CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` + BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + Keys []TableKey `gorm:"many2many:governance_virtual_key_keys;constraint:OnDelete:CASCADE" json:"keys"` + + // Relationships + Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"` + Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"` + Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"` + RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"` + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableModelPricing represents pricing information for AI models +type TableModelPricing struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider_mode" json:"model"` + Provider string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"provider"` + InputCostPerToken float64 `gorm:"not null" json:"input_cost_per_token"` + OutputCostPerToken float64 `gorm:"not null" json:"output_cost_per_token"` + Mode string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"mode"` + + // Additional pricing for media + InputCostPerImage *float64 `gorm:"default:null" json:"input_cost_per_image,omitempty"` + InputCostPerVideoPerSecond *float64 `gorm:"default:null" json:"input_cost_per_video_per_second,omitempty"` + InputCostPerAudioPerSecond *float64 `gorm:"default:null" json:"input_cost_per_audio_per_second,omitempty"` + + // Character-based pricing + InputCostPerCharacter *float64 `gorm:"default:null" json:"input_cost_per_character,omitempty"` + OutputCostPerCharacter *float64 `gorm:"default:null" json:"output_cost_per_character,omitempty"` + + // Pricing above 128k tokens + InputCostPerTokenAbove128kTokens *float64 `gorm:"default:null" json:"input_cost_per_token_above_128k_tokens,omitempty"` + InputCostPerCharacterAbove128kTokens *float64 `gorm:"default:null" json:"input_cost_per_character_above_128k_tokens,omitempty"` + InputCostPerImageAbove128kTokens *float64 `gorm:"default:null" json:"input_cost_per_image_above_128k_tokens,omitempty"` + InputCostPerVideoPerSecondAbove128kTokens *float64 `gorm:"default:null" json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` + InputCostPerAudioPerSecondAbove128kTokens *float64 `gorm:"default:null" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` + OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null" json:"output_cost_per_token_above_128k_tokens,omitempty"` + OutputCostPerCharacterAbove128kTokens *float64 `gorm:"default:null" json:"output_cost_per_character_above_128k_tokens,omitempty"` + + // Cache and batch pricing + CacheReadInputTokenCost *float64 `gorm:"default:null" json:"cache_read_input_token_cost,omitempty"` + InputCostPerTokenBatches *float64 `gorm:"default:null" json:"input_cost_per_token_batches,omitempty"` + OutputCostPerTokenBatches *float64 `gorm:"default:null" json:"output_cost_per_token_batches,omitempty"` +} + +// Table names +func (TableBudget) TableName() string { return "governance_budgets" } +func (TableRateLimit) TableName() string { return "governance_rate_limits" } +func (TableCustomer) TableName() string { return "governance_customers" } +func (TableTeam) TableName() string { return "governance_teams" } +func (TableVirtualKey) TableName() string { return "governance_virtual_keys" } +func (TableConfig) TableName() string { return "governance_config" } +func (TableModelPricing) TableName() string { return "governance_model_pricing" } + +// GORM Hooks for validation and constraints + +// BeforeSave hook for VirtualKey to enforce mutual exclusion +func (vk *TableVirtualKey) BeforeSave(tx *gorm.DB) error { + // Enforce mutual exclusion: VK can belong to either Team OR Customer, not both + if vk.TeamID != nil && vk.CustomerID != nil { + return fmt.Errorf("virtual key cannot belong to both team and customer") + } + return nil +} + +// BeforeSave hook for Budget to validate reset duration format and max limit +func (b *TableBudget) BeforeSave(tx *gorm.DB) error { + // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") + if _, err := ParseDuration(b.ResetDuration); err != nil { + return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) + } + + // Validate that MaxLimit is not negative (budgets should be positive) + if b.MaxLimit < 0 { + return fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit) + } + + return nil +} + +// BeforeSave hook for RateLimit to validate reset duration formats +func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error { + // Validate token reset duration if provided + if rl.TokenResetDuration != nil { + if _, err := ParseDuration(*rl.TokenResetDuration); err != nil { + return fmt.Errorf("invalid token reset duration format: %s", *rl.TokenResetDuration) + } + } + + // Validate request reset duration if provided + if rl.RequestResetDuration != nil { + if _, err := ParseDuration(*rl.RequestResetDuration); err != nil { + return fmt.Errorf("invalid request reset duration format: %s", *rl.RequestResetDuration) + } + } + + // Validate that if a max limit is set, a reset duration is also provided + if rl.TokenMaxLimit != nil && rl.TokenResetDuration == nil { + return fmt.Errorf("token_reset_duration is required when token_max_limit is set") + } + if rl.RequestMaxLimit != nil && rl.RequestResetDuration == nil { + return fmt.Errorf("request_reset_duration is required when request_max_limit is set") + } + + return nil +} + +func (vk *TableVirtualKey) AfterFind(tx *gorm.DB) error { + if vk.Keys != nil { + // Clear sensitive data from associated keys, keeping only key IDs and non-sensitive metadata + for i := range vk.Keys { + key := &vk.Keys[i] + + // Clear the actual API key value + key.Value = "" + + // Clear all Azure-related sensitive fields + key.AzureEndpoint = nil + key.AzureAPIVersion = nil + key.AzureDeploymentsJSON = nil + key.AzureKeyConfig = nil + + // Clear all Vertex-related sensitive fields + key.VertexProjectID = nil + key.VertexRegion = nil + key.VertexAuthCredentials = nil + key.VertexKeyConfig = nil + + // Clear all Bedrock-related sensitive fields + key.BedrockAccessKey = nil + key.BedrockSecretKey = nil + key.BedrockSessionToken = nil + key.BedrockRegion = nil + key.BedrockARN = nil + key.BedrockDeploymentsJSON = nil + key.BedrockKeyConfig = nil + + vk.Keys[i] = *key + } + } + return nil +} + +// Database constraints and indexes +func (vk *TableVirtualKey) AfterAutoMigrate(tx *gorm.DB) error { + // Ensure only one of TeamID or CustomerID is set + return tx.Exec(` + CREATE OR REPLACE FUNCTION check_vk_exclusion() RETURNS TRIGGER AS $$ + BEGIN + IF NEW.team_id IS NOT NULL AND NEW.customer_id IS NOT NULL THEN + RAISE EXCEPTION 'Virtual key cannot belong to both team and customer'; + END IF; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + + DROP TRIGGER IF EXISTS vk_exclusion_trigger ON governance_virtual_keys; + CREATE TRIGGER vk_exclusion_trigger + BEFORE INSERT OR UPDATE ON governance_virtual_keys + FOR EACH ROW EXECUTE FUNCTION check_vk_exclusion(); + `).Error +} + +// Utility function to parse duration strings +func ParseDuration(duration string) (time.Duration, error) { + if duration == "" { + return 0, fmt.Errorf("duration is empty") + } + + // Handle special cases for days, weeks, months, years + switch { + case duration[len(duration)-1:] == "d": + days := duration[:len(duration)-1] + if d, err := time.ParseDuration(days + "h"); err == nil { + return d * 24, nil + } + return 0, fmt.Errorf("invalid day duration: %s", duration) + case duration[len(duration)-1:] == "w": + weeks := duration[:len(duration)-1] + if w, err := time.ParseDuration(weeks + "h"); err == nil { + return w * 24 * 7, nil + } + return 0, fmt.Errorf("invalid week duration: %s", duration) + case duration[len(duration)-1:] == "M": + months := duration[:len(duration)-1] + if m, err := time.ParseDuration(months + "h"); err == nil { + return m * 24 * 30, nil // Approximate month as 30 days + } + return 0, fmt.Errorf("invalid month duration: %s", duration) + case duration[len(duration)-1:] == "Y": + years := duration[:len(duration)-1] + if y, err := time.ParseDuration(years + "h"); err == nil { + return y * 24 * 365, nil // Approximate year as 365 days + } + return 0, fmt.Errorf("invalid year duration: %s", duration) + default: + return time.ParseDuration(duration) + } +} diff --git a/framework/configstore/utils.go b/framework/configstore/utils.go new file mode 100644 index 000000000..fe7137d76 --- /dev/null +++ b/framework/configstore/utils.go @@ -0,0 +1,161 @@ +package configstore + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// marshalToString marshals the given value to a JSON string. +func marshalToString(v any) (string, error) { + if v == nil { + return "", nil + } + data, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(data), nil +} + +// marshalToStringPtr marshals the given value to a JSON string and returns a pointer to the string. +func marshalToStringPtr(v any) (*string, error) { + if v == nil { + return nil, nil + } + data, err := marshalToString(v) + if err != nil { + return nil, err + } + return &data, nil +} + +// deepCopy creates a deep copy of a given type +func deepCopy[T any](in T) (T, error) { + var out T + b, err := json.Marshal(in) + if err != nil { + return out, err + } + err = json.Unmarshal(b, &out) + return out, err +} + +// substituteEnvVars replaces resolved environment variable values with their original env.VAR_NAME references +func substituteEnvVars(config *ProviderConfig, provider schemas.ModelProvider, envKeys map[string][]EnvKeyInfo) { + // Create a map for quick lookup of env vars by provider and key ID + envVarMap := make(map[string]string) // key: "provider.keyID.field" -> env var name + + for envVar, keyInfos := range envKeys { + for _, keyInfo := range keyInfos { + if keyInfo.Provider == provider { + // For API keys + if keyInfo.KeyType == "api_key" { + envVarMap[fmt.Sprintf("%s.%s.value", provider, keyInfo.KeyID)] = envVar + } + // For Azure config + if keyInfo.KeyType == "azure_config" { + field := strings.TrimPrefix(keyInfo.ConfigPath, fmt.Sprintf("providers.%s.keys[%s].azure_key_config.", provider, keyInfo.KeyID)) + envVarMap[fmt.Sprintf("%s.%s.azure.%s", provider, keyInfo.KeyID, field)] = envVar + } + // For Vertex config + if keyInfo.KeyType == "vertex_config" { + field := strings.TrimPrefix(keyInfo.ConfigPath, fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.", provider, keyInfo.KeyID)) + envVarMap[fmt.Sprintf("%s.%s.vertex.%s", provider, keyInfo.KeyID, field)] = envVar + } + // For Bedrock config + if keyInfo.KeyType == "bedrock_config" { + field := strings.TrimPrefix(keyInfo.ConfigPath, fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.", provider, keyInfo.KeyID)) + envVarMap[fmt.Sprintf("%s.%s.bedrock.%s", provider, keyInfo.KeyID, field)] = envVar + } + } + } + } + + // Substitute values in keys + for i, key := range config.Keys { + keyPrefix := fmt.Sprintf("%s.%s", provider, key.ID) + + // Substitute API key value + if envVar, exists := envVarMap[fmt.Sprintf("%s.value", keyPrefix)]; exists { + config.Keys[i].Value = fmt.Sprintf("env.%s", envVar) + } + + // Substitute Azure config + if key.AzureKeyConfig != nil { + if envVar, exists := envVarMap[fmt.Sprintf("%s.azure.endpoint", keyPrefix)]; exists { + config.Keys[i].AzureKeyConfig.Endpoint = fmt.Sprintf("env.%s", envVar) + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.azure.api_version", keyPrefix)]; exists { + apiVersion := fmt.Sprintf("env.%s", envVar) + config.Keys[i].AzureKeyConfig.APIVersion = &apiVersion + } + } + + // Substitute Vertex config + if key.VertexKeyConfig != nil { + if envVar, exists := envVarMap[fmt.Sprintf("%s.vertex.project_id", keyPrefix)]; exists { + config.Keys[i].VertexKeyConfig.ProjectID = fmt.Sprintf("env.%s", envVar) + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.vertex.region", keyPrefix)]; exists { + config.Keys[i].VertexKeyConfig.Region = fmt.Sprintf("env.%s", envVar) + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.vertex.auth_credentials", keyPrefix)]; exists { + config.Keys[i].VertexKeyConfig.AuthCredentials = fmt.Sprintf("env.%s", envVar) + } + } + + // Substitute Bedrock config + if key.BedrockKeyConfig != nil { + if envVar, exists := envVarMap[fmt.Sprintf("%s.bedrock.access_key", keyPrefix)]; exists { + config.Keys[i].BedrockKeyConfig.AccessKey = fmt.Sprintf("env.%s", envVar) + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.bedrock.secret_key", keyPrefix)]; exists { + config.Keys[i].BedrockKeyConfig.SecretKey = fmt.Sprintf("env.%s", envVar) + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.bedrock.session_token", keyPrefix)]; exists { + config.Keys[i].BedrockKeyConfig.SessionToken = &[]string{fmt.Sprintf("env.%s", envVar)}[0] + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.bedrock.region", keyPrefix)]; exists { + config.Keys[i].BedrockKeyConfig.Region = &[]string{fmt.Sprintf("env.%s", envVar)}[0] + } + if envVar, exists := envVarMap[fmt.Sprintf("%s.bedrock.arn", keyPrefix)]; exists { + config.Keys[i].BedrockKeyConfig.ARN = &[]string{fmt.Sprintf("env.%s", envVar)}[0] + } + } + } +} + +// substituteMCPEnvVars replaces resolved environment variable values with their original env.VAR_NAME references for MCP config +func substituteMCPEnvVars(config *schemas.MCPConfig, envKeys map[string][]EnvKeyInfo) { + // Create a map for quick lookup of env vars by MCP client name + envVarMap := make(map[string]string) // key: "clientName.connection_string" -> env var name + + for envVar, keyInfos := range envKeys { + for _, keyInfo := range keyInfos { + // For MCP connection strings + if keyInfo.KeyType == "connection_string" { + // Extract client name from config path like "mcp.client_configs.clientName.connection_string" + pathParts := strings.Split(keyInfo.ConfigPath, ".") + if len(pathParts) >= 3 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" { + clientName := pathParts[2] + envVarMap[fmt.Sprintf("%s.connection_string", clientName)] = envVar + } + } + } + } + + // Substitute values in MCP client configs + for i, clientConfig := range config.ClientConfigs { + clientPrefix := clientConfig.Name + + // Substitute connection string + if clientConfig.ConnectionString != nil { + if envVar, exists := envVarMap[fmt.Sprintf("%s.connection_string", clientPrefix)]; exists { + config.ClientConfigs[i].ConnectionString = &[]string{fmt.Sprintf("env.%s", envVar)}[0] + } + } + } +} diff --git a/framework/go.mod b/framework/go.mod new file mode 100644 index 000000000..fded729cd --- /dev/null +++ b/framework/go.mod @@ -0,0 +1,89 @@ +module github.com/maximhq/bifrost/framework + +go 1.24 + +toolchain go1.24.3 + +require ( + github.com/google/uuid v1.6.0 + github.com/maximhq/bifrost/core v1.1.37 + github.com/redis/go-redis/v9 v9.12.1 + github.com/stretchr/testify v1.10.0 + github.com/weaviate/weaviate v1.31.5 + github.com/weaviate/weaviate-go-client/v5 v5.2.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.30.1 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/framework/go.sum b/framework/go.sum new file mode 100644 index 000000000..d5fcb73c4 --- /dev/null +++ b/framework/go.sum @@ -0,0 +1,353 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/framework/list.go b/framework/list.go new file mode 100644 index 000000000..7e32cfdd4 --- /dev/null +++ b/framework/list.go @@ -0,0 +1,14 @@ +// Package framework provides a list of dependencies that are required for the framework to work. +package framework + +// FrameworkDependency is a type that represents a dependency of the framework. +type FrameworkDependency string + +const ( + // FrameworkDependencyVectorStore indicates the framework requires a VectorStore implementation. + FrameworkDependencyVectorStore FrameworkDependency = "vector_store" + // FrameworkDependencyConfigStore indicates the framework requires a ConfigStore implementation. + FrameworkDependencyConfigStore FrameworkDependency = "config_store" + // FrameworkDependencyLogsStore indicates the framework requires a LogsStore implementation. + FrameworkDependencyLogsStore FrameworkDependency = "logs_store" +) diff --git a/framework/logstore/config.go b/framework/logstore/config.go new file mode 100644 index 000000000..b9ff51f52 --- /dev/null +++ b/framework/logstore/config.go @@ -0,0 +1,55 @@ +// Package logstore provides a logs store for Bifrost. +package logstore + +import ( + "encoding/json" + "fmt" +) + +// Config represents the configuration for the logs store. +type Config struct { + Enabled bool `json:"enabled"` + Type LogStoreType `json:"type"` + Config any `json:"config"` +} + +// UnmarshalJSON is the custom unmarshal logic for Config +func (c *Config) UnmarshalJSON(data []byte) error { + // First, unmarshal into a temporary struct to get the basic fields + type TempConfig struct { + Enabled bool `json:"enabled"` + Type LogStoreType `json:"type"` + Config json.RawMessage `json:"config"` // Keep as raw JSON + } + + var temp TempConfig + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal logs config: %w", err) + } + + // Set basic fields + c.Enabled = temp.Enabled + c.Type = temp.Type + + if !temp.Enabled { + c.Config = nil + return nil + } + + // Parse the config field based on type + switch temp.Type { + case LogStoreTypeSQLite: + if len(temp.Config) == 0 { + return fmt.Errorf("missing sqlite config payload") + } + var sqliteConfig SQLiteConfig + if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil { + return fmt.Errorf("failed to unmarshal sqlite config: %w", err) + } + c.Config = &sqliteConfig + + default: + return fmt.Errorf("unknown log store type: %s", temp.Type) + } + return nil +} diff --git a/framework/logstore/errors.go b/framework/logstore/errors.go new file mode 100644 index 000000000..650d767d3 --- /dev/null +++ b/framework/logstore/errors.go @@ -0,0 +1,7 @@ +package logstore + +import "fmt" + +var ( + ErrNotFound = fmt.Errorf("log not found") +) diff --git a/framework/logstore/sqlite.go b/framework/logstore/sqlite.go new file mode 100644 index 000000000..e09ec7cf4 --- /dev/null +++ b/framework/logstore/sqlite.go @@ -0,0 +1,234 @@ +package logstore + +import ( + "database/sql" + "errors" + "fmt" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// SQLiteConfig represents the configuration for a SQLite database. +type SQLiteConfig struct { + Path string `json:"path"` +} + +// SQLiteLogStore represents a logs store that uses a SQLite database. +type SQLiteLogStore struct { + db *gorm.DB + logger schemas.Logger +} + +// Create inserts a new log entry into the database. +func (s *SQLiteLogStore) Create(entry *Log) error { + return s.db.Create(entry).Error +} + +// Update updates a log entry in the database. +func (s *SQLiteLogStore) Update(id string, entry any) error { + tx := s.db.Model(&Log{}).Where("id = ?", id).Updates(entry) + if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + return ErrNotFound + } + if tx.RowsAffected == 0 { + return ErrNotFound + } + return tx.Error +} + +// SearchLogs searches for logs in the database. +func (s *SQLiteLogStore) SearchLogs(filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) { + baseQuery := s.db.Model(&Log{}) + + // Apply filters efficiently + if len(filters.Providers) > 0 { + baseQuery = baseQuery.Where("provider IN ?", filters.Providers) + } + if len(filters.Models) > 0 { + baseQuery = baseQuery.Where("model IN ?", filters.Models) + } + if len(filters.Status) > 0 { + baseQuery = baseQuery.Where("status IN ?", filters.Status) + } + if len(filters.Objects) > 0 { + baseQuery = baseQuery.Where("object_type IN ?", filters.Objects) + } + if filters.StartTime != nil { + baseQuery = baseQuery.Where("timestamp >= ?", *filters.StartTime) + } + if filters.EndTime != nil { + baseQuery = baseQuery.Where("timestamp <= ?", *filters.EndTime) + } + if filters.MinLatency != nil { + baseQuery = baseQuery.Where("latency >= ?", *filters.MinLatency) + } + if filters.MaxLatency != nil { + baseQuery = baseQuery.Where("latency <= ?", *filters.MaxLatency) + } + if filters.MinTokens != nil { + baseQuery = baseQuery.Where("total_tokens >= ?", *filters.MinTokens) + } + if filters.MaxTokens != nil { + baseQuery = baseQuery.Where("total_tokens <= ?", *filters.MaxTokens) + } + if filters.MinCost != nil { + baseQuery = baseQuery.Where("cost >= ?", *filters.MinCost) + } + if filters.MaxCost != nil { + baseQuery = baseQuery.Where("cost <= ?", *filters.MaxCost) + } + if filters.ContentSearch != "" { + baseQuery = baseQuery.Where("content_summary LIKE ?", "%"+filters.ContentSearch+"%") + } + + // Get total count + var totalCount int64 + if err := baseQuery.Count(&totalCount).Error; err != nil { + return nil, err + } + + // Initialize stats + stats := SearchStats{} + + // Calculate statistics efficiently if we have data + if totalCount > 0 { + // Total requests should include all requests (processing, success, error) + stats.TotalRequests = totalCount + + // Get completed requests count (success + error, excluding processing) for success rate calculation + var completedCount int64 + completedQuery := baseQuery.Session(&gorm.Session{}) + if err := completedQuery.Where("status IN ?", []string{"success", "error"}).Count(&completedCount).Error; err != nil { + return nil, err + } + + if completedCount > 0 { + // Calculate success rate based on completed requests only + var successCount int64 + successQuery := baseQuery.Session(&gorm.Session{}) + if err := successQuery.Where("status = ?", "success").Count(&successCount).Error; err != nil { + return nil, err + } + stats.SuccessRate = float64(successCount) / float64(completedCount) * 100 + + // Calculate average latency and total tokens in a single query for better performance + var result struct { + AvgLatency sql.NullFloat64 `json:"avg_latency"` + TotalTokens sql.NullInt64 `json:"total_tokens"` + TotalCost sql.NullFloat64 `json:"total_cost"` + } + + statsQuery := baseQuery.Session(&gorm.Session{}) + if err := statsQuery.Select("AVG(latency) as avg_latency, SUM(total_tokens) as total_tokens, SUM(cost) as total_cost").Scan(&result).Error; err != nil { + return nil, err + } + + if result.AvgLatency.Valid { + stats.AverageLatency = result.AvgLatency.Float64 + } + if result.TotalTokens.Valid { + stats.TotalTokens = result.TotalTokens.Int64 + } + if result.TotalCost.Valid { + stats.TotalCost = result.TotalCost.Float64 + } + } + } + + // Build order clause + direction := "DESC" + if pagination.Order == "asc" { + direction = "ASC" + } + + var orderClause string + switch pagination.SortBy { + case "timestamp": + orderClause = "timestamp " + direction + case "latency": + orderClause = "latency " + direction + case "tokens": + orderClause = "total_tokens " + direction + case "cost": + orderClause = "cost " + direction + default: + orderClause = "timestamp " + direction + } + + // Execute main query with sorting and pagination + var logs []Log + mainQuery := baseQuery.Order(orderClause) + + if pagination.Limit > 0 { + mainQuery = mainQuery.Limit(pagination.Limit) + } + if pagination.Offset > 0 { + mainQuery = mainQuery.Offset(pagination.Offset) + } + + if err := mainQuery.Find(&logs).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &SearchResult{ + Logs: logs, + Pagination: pagination, + Stats: stats, + }, nil + } + return nil, err + } + + return &SearchResult{ + Logs: logs, + Pagination: pagination, + Stats: stats, + }, nil +} + +// FindFirst gets a log entry from the database. +func (s *SQLiteLogStore) FindFirst(query any, fields ...string) (*Log, error) { + var log Log + if err := s.db.Select(fields).Where(query).First(&log).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &log, nil +} + +// CleanupLogs deletes old log entries from the database. +func (s *SQLiteLogStore) CleanupLogs(since time.Time) error { + result := s.db.Where("status = ? AND created_at < ?", "processing", since).Delete(&Log{}) + if result.Error != nil { + return fmt.Errorf("failed to cleanup old processing logs: %w", result.Error) + } + return nil +} + +// FindAll finds all log entries from the database. +func (s *SQLiteLogStore) FindAll(query any, fields ...string) ([]*Log, error) { + var logs []*Log + if err := s.db.Select(fields).Where(query).Find(&logs).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return []*Log{}, nil + } + return nil, err + } + return logs, nil +} + +func newSqliteLogStore(config *SQLiteConfig, logger schemas.Logger) (*SQLiteLogStore, error) { + // Configure SQLite with proper settings to handle concurrent access + dsn := fmt.Sprintf("%s??_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000", config.Path) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + return nil, err + } + if err := db.AutoMigrate(&Log{}); err != nil { + return nil, err + } + return &SQLiteLogStore{db: db, logger: logger}, nil +} diff --git a/framework/logstore/store.go b/framework/logstore/store.go new file mode 100644 index 000000000..d19851a19 --- /dev/null +++ b/framework/logstore/store.go @@ -0,0 +1,39 @@ +package logstore + +import ( + "fmt" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// LogStoreType represents the type of log store. +type LogStoreType string + +// LogStoreTypeSQLite is the type of log store for SQLite. +const ( + LogStoreTypeSQLite LogStoreType = "sqlite" +) + +// LogStore is the interface for the log store. +type LogStore interface { + Create(entry *Log) error + FindFirst(query any, fields ...string) (*Log, error) + FindAll(query any, fields ...string) ([]*Log, error) + SearchLogs(filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) + Update(id string, entry any) error + CleanupLogs(since time.Time) error +} + +// NewLogStore creates a new log store based on the configuration. +func NewLogStore(config *Config, logger schemas.Logger) (LogStore, error) { + switch config.Type { + case LogStoreTypeSQLite: + if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok { + return newSqliteLogStore(sqliteConfig, logger) + } + return nil, fmt.Errorf("invalid sqlite config: %T", config.Config) + default: + return nil, fmt.Errorf("unsupported log store type: %s", config.Type) + } +} diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go new file mode 100644 index 000000000..a815142ec --- /dev/null +++ b/framework/logstore/tables.go @@ -0,0 +1,405 @@ +package logstore + +import ( + "encoding/json" + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "gorm.io/gorm" +) + +type SortBy string + +const ( + SortByTimestamp SortBy = "timestamp" + SortByLatency SortBy = "latency" + SortByTokens SortBy = "tokens" + SortByCost SortBy = "cost" +) + +type SortOrder string + +const ( + SortAsc SortOrder = "asc" + SortDesc SortOrder = "desc" +) + +// SearchFilters represents the available filters for log searches +type SearchFilters struct { + Providers []string `json:"providers,omitempty"` + Models []string `json:"models,omitempty"` + Status []string `json:"status,omitempty"` + Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + MinLatency *float64 `json:"min_latency,omitempty"` + MaxLatency *float64 `json:"max_latency,omitempty"` + MinTokens *int `json:"min_tokens,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MinCost *float64 `json:"min_cost,omitempty"` + MaxCost *float64 `json:"max_cost,omitempty"` + ContentSearch string `json:"content_search,omitempty"` +} + +// PaginationOptions represents pagination parameters +type PaginationOptions struct { + Limit int `json:"limit"` + Offset int `json:"offset"` + SortBy string `json:"sort_by"` // "timestamp", "latency", "tokens", "cost" + Order string `json:"order"` // "asc", "desc" +} + +// SearchResult represents the result of a log search +type SearchResult struct { + Logs []Log `json:"logs"` + Pagination PaginationOptions `json:"pagination"` + Stats SearchStats `json:"stats"` +} + +type SearchStats struct { + TotalRequests int64 `json:"total_requests"` + SuccessRate float64 `json:"success_rate"` // Percentage of successful requests + AverageLatency float64 `json:"average_latency"` // Average latency in milliseconds + TotalTokens int64 `json:"total_tokens"` // Total tokens used + TotalCost float64 `json:"total_cost"` // Total cost in dollars +} + +// Log represents a complete log entry for a request/response cycle +// This is the GORM model with appropriate tags +type Log struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Timestamp time.Time `gorm:"index;not null" json:"timestamp"` + Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding + Provider string `gorm:"type:varchar(255);index;not null" json:"provider"` + Model string `gorm:"type:varchar(255);index;not null" json:"model"` + InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.BifrostMessage + OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostMessage + EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized *[][]float32 + Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters + Tools string `gorm:"type:text" json:"-"` // JSON serialized *[]schemas.Tool + ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized *[]schemas.ToolCall + SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput + TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput + SpeechOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostSpeech + TranscriptionOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostTranscribe + CacheDebug string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostCacheDebug + Latency *float64 `json:"latency,omitempty"` + TokenUsage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.LLMUsage + Cost *float64 `gorm:"index" json:"cost,omitempty"` // Cost in dollars (total cost of the request - includes cache lookup cost) + Status string `gorm:"type:varchar(50);index;not null" json:"status"` // "processing", "success", or "error" + ErrorDetails string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostError + Stream bool `gorm:"default:false" json:"stream"` // true if this was a streaming response + ContentSummary string `gorm:"type:text" json:"-"` // For content search + + // Denormalized token fields for easier querying + PromptTokens int `gorm:"default:0" json:"-"` + CompletionTokens int `gorm:"default:0" json:"-"` + TotalTokens int `gorm:"default:0" json:"-"` + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + + // Virtual fields for JSON output - these will be populated when needed + InputHistoryParsed []schemas.BifrostMessage `gorm:"-" json:"input_history,omitempty"` + OutputMessageParsed *schemas.BifrostMessage `gorm:"-" json:"output_message,omitempty"` + EmbeddingOutputParsed *[]schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"` + ParamsParsed *schemas.ModelParameters `gorm:"-" json:"params,omitempty"` + ToolsParsed *[]schemas.Tool `gorm:"-" json:"tools,omitempty"` + ToolCallsParsed *[]schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"` + TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"` + ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"` + SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` + TranscriptionInputParsed *schemas.TranscriptionInput `gorm:"-" json:"transcription_input,omitempty"` + SpeechOutputParsed *schemas.BifrostSpeech `gorm:"-" json:"speech_output,omitempty"` + TranscriptionOutputParsed *schemas.BifrostTranscribe `gorm:"-" json:"transcription_output,omitempty"` + CacheDebugParsed *schemas.BifrostCacheDebug `gorm:"-" json:"cache_debug,omitempty"` +} + +// TableName sets the table name for GORM +func (Log) TableName() string { + return "logs" +} + +// BeforeCreate GORM hook to set created_at and serialize JSON fields +func (l *Log) BeforeCreate(tx *gorm.DB) error { + if l.CreatedAt.IsZero() { + l.CreatedAt = time.Now().UTC() + } + return l.SerializeFields() +} + +// BeforeSave GORM hook to serialize JSON fields +func (l *Log) BeforeSave(tx *gorm.DB) error { + return l.SerializeFields() +} + +// AfterFind GORM hook to deserialize JSON fields +func (l *Log) AfterFind(tx *gorm.DB) error { + return l.DeserializeFields() +} + +// SerializeFields converts Go structs to JSON strings for storage +func (l *Log) SerializeFields() error { + if l.InputHistoryParsed != nil { + if data, err := json.Marshal(l.InputHistoryParsed); err != nil { + return err + } else { + l.InputHistory = string(data) + } + } + + if l.OutputMessageParsed != nil { + if data, err := json.Marshal(l.OutputMessageParsed); err != nil { + return err + } else { + l.OutputMessage = string(data) + } + } + + if l.EmbeddingOutputParsed != nil { + if data, err := json.Marshal(l.EmbeddingOutputParsed); err != nil { + return err + } else { + l.EmbeddingOutput = string(data) + } + } + + if l.SpeechInputParsed != nil { + if data, err := json.Marshal(l.SpeechInputParsed); err != nil { + return err + } else { + l.SpeechInput = string(data) + } + } + + if l.TranscriptionInputParsed != nil { + if data, err := json.Marshal(l.TranscriptionInputParsed); err != nil { + return err + } else { + l.TranscriptionInput = string(data) + } + } + + if l.SpeechOutputParsed != nil { + if data, err := json.Marshal(l.SpeechOutputParsed); err != nil { + return err + } else { + l.SpeechOutput = string(data) + } + } + + if l.TranscriptionOutputParsed != nil { + if data, err := json.Marshal(l.TranscriptionOutputParsed); err != nil { + return err + } else { + l.TranscriptionOutput = string(data) + } + } + + if l.ParamsParsed != nil { + if data, err := json.Marshal(l.ParamsParsed); err != nil { + return err + } else { + l.Params = string(data) + } + } + + if l.ToolsParsed != nil { + if data, err := json.Marshal(l.ToolsParsed); err != nil { + return err + } else { + l.Tools = string(data) + } + } + + if l.ToolCallsParsed != nil { + if data, err := json.Marshal(l.ToolCallsParsed); err != nil { + return err + } else { + l.ToolCalls = string(data) + } + } + + if l.TokenUsageParsed != nil { + if data, err := json.Marshal(l.TokenUsageParsed); err != nil { + return err + } else { + l.TokenUsage = string(data) + } + // Update denormalized fields for easier querying + l.PromptTokens = l.TokenUsageParsed.PromptTokens + l.CompletionTokens = l.TokenUsageParsed.CompletionTokens + l.TotalTokens = l.TokenUsageParsed.TotalTokens + } + + if l.ErrorDetailsParsed != nil { + if data, err := json.Marshal(l.ErrorDetailsParsed); err != nil { + return err + } else { + l.ErrorDetails = string(data) + } + } + + if l.CacheDebugParsed != nil { + if data, err := json.Marshal(l.CacheDebugParsed); err != nil { + return err + } else { + l.CacheDebug = string(data) + } + } + + // Build content summary for search + l.ContentSummary = l.BuildContentSummary() + + return nil +} + +// DeserializeFields converts JSON strings back to Go structs +func (l *Log) DeserializeFields() error { + if l.InputHistory != "" { + if err := json.Unmarshal([]byte(l.InputHistory), &l.InputHistoryParsed); err != nil { + // Log error but don't fail the operation - initialize as empty slice + l.InputHistoryParsed = []schemas.BifrostMessage{} + } + } + + if l.OutputMessage != "" { + if err := json.Unmarshal([]byte(l.OutputMessage), &l.OutputMessageParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.OutputMessageParsed = nil + } + } + + if l.EmbeddingOutput != "" { + if err := json.Unmarshal([]byte(l.EmbeddingOutput), &l.EmbeddingOutputParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.EmbeddingOutputParsed = nil + } + } + + if l.Params != "" { + if err := json.Unmarshal([]byte(l.Params), &l.ParamsParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.ParamsParsed = nil + } + } + + if l.Tools != "" { + if err := json.Unmarshal([]byte(l.Tools), &l.ToolsParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.ToolsParsed = nil + } + } + + if l.ToolCalls != "" { + if err := json.Unmarshal([]byte(l.ToolCalls), &l.ToolCallsParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.ToolCallsParsed = nil + } + } + + if l.TokenUsage != "" { + if err := json.Unmarshal([]byte(l.TokenUsage), &l.TokenUsageParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.TokenUsageParsed = nil + } + } + + if l.ErrorDetails != "" { + if err := json.Unmarshal([]byte(l.ErrorDetails), &l.ErrorDetailsParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.ErrorDetailsParsed = nil + } + } + + // Deserialize speech and transcription fields + if l.SpeechInput != "" { + if err := json.Unmarshal([]byte(l.SpeechInput), &l.SpeechInputParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.SpeechInputParsed = nil + } + } + + if l.TranscriptionInput != "" { + if err := json.Unmarshal([]byte(l.TranscriptionInput), &l.TranscriptionInputParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.TranscriptionInputParsed = nil + } + } + + if l.SpeechOutput != "" { + if err := json.Unmarshal([]byte(l.SpeechOutput), &l.SpeechOutputParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.SpeechOutputParsed = nil + } + } + + if l.TranscriptionOutput != "" { + if err := json.Unmarshal([]byte(l.TranscriptionOutput), &l.TranscriptionOutputParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.TranscriptionOutputParsed = nil + } + } + + if l.CacheDebug != "" { + if err := json.Unmarshal([]byte(l.CacheDebug), &l.CacheDebugParsed); err != nil { + // Log error but don't fail the operation - initialize as nil + l.CacheDebugParsed = nil + } + } + + return nil +} + +// BuildContentSummary creates a searchable text summary +func (l *Log) BuildContentSummary() string { + var parts []string + + // Add input messages + for _, msg := range l.InputHistoryParsed { + // Access content through the Content field + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + parts = append(parts, *msg.Content.ContentStr) + } + // If content blocks exist, extract text from them + if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + parts = append(parts, *block.Text) + } + } + } + } + + // Add output message + if l.OutputMessageParsed != nil { + if l.OutputMessageParsed.Content.ContentStr != nil && *l.OutputMessageParsed.Content.ContentStr != "" { + parts = append(parts, *l.OutputMessageParsed.Content.ContentStr) + } + // If content blocks exist, extract text from them + if l.OutputMessageParsed.Content.ContentBlocks != nil { + for _, block := range *l.OutputMessageParsed.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + parts = append(parts, *block.Text) + } + } + } + } + + // Add speech input content + if l.SpeechInputParsed != nil && l.SpeechInputParsed.Input != "" { + parts = append(parts, l.SpeechInputParsed.Input) + } + + // Add transcription output content + if l.TranscriptionOutputParsed != nil && l.TranscriptionOutputParsed.Text != "" { + parts = append(parts, l.TranscriptionOutputParsed.Text) + } + + // Add error details + if l.ErrorDetailsParsed != nil && l.ErrorDetailsParsed.Error.Message != "" { + parts = append(parts, l.ErrorDetailsParsed.Error.Message) + } + + return strings.Join(parts, " ") +} diff --git a/framework/pricing/main.go b/framework/pricing/main.go new file mode 100644 index 000000000..45880cbd6 --- /dev/null +++ b/framework/pricing/main.go @@ -0,0 +1,344 @@ +package pricing + +import ( + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" +) + +// Default sync interval and config key +const ( + DefaultPricingSyncInterval = 24 * time.Hour + LastPricingSyncKey = "LastModelPricingSync" + PricingFileURL = "https://getbifrost.ai/datasheet" + TokenTierAbove128K = 128000 +) + +type PricingManager struct { + configStore configstore.ConfigStore + logger schemas.Logger + + // In-memory cache for fast access - direct map for O(1) lookups + pricingData map[string]configstore.TableModelPricing + mu sync.RWMutex + + // Background sync worker + syncTicker *time.Ticker + done chan struct{} + wg sync.WaitGroup +} + +// PricingData represents the structure of the pricing.json file +type PricingData map[string]PricingEntry + +// PricingEntry represents a single model's pricing information +type PricingEntry struct { + // Basic pricing + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + Provider string `json:"provider"` + Mode string `json:"mode"` + + // Additional pricing for media + InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` + InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` + InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` + + // Character-based pricing + InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` + OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"` + + // Pricing above 128k tokens + InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` + InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"` + InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` + InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` + InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` + OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` + OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"` + + // Cache and batch pricing + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` + InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` + OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` +} + +func Init(configStore configstore.ConfigStore, logger schemas.Logger) (*PricingManager, error) { + pm := &PricingManager{ + configStore: configStore, + logger: logger, + pricingData: make(map[string]configstore.TableModelPricing), + done: make(chan struct{}), + } + + if configStore != nil { + // Load initial pricing data + if err := pm.loadPricingFromDatabase(); err != nil { + return nil, fmt.Errorf("failed to load initial pricing data: %w", err) + } + + // For the bootup we sync pricing data from file to database + if err := pm.syncPricing(); err != nil { + return nil, fmt.Errorf("failed to sync pricing data: %w", err) + } + + } else { + // Load pricing data from config memory + if err := pm.loadPricingIntoMemory(); err != nil { + return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err) + } + } + + // Start background sync worker + pm.startSyncWorker() + pm.configStore = configStore + pm.logger = logger + + return pm, nil +} + +func (pm *PricingManager) CalculateCost(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType) float64 { + if result == nil || provider == "" || model == "" || requestType == "" { + return 0.0 + } + + var usage *schemas.LLMUsage + var audioSeconds *int + var audioTokenDetails *schemas.AudioTokenDetails + + //TODO: Detect cache and batch operations + isCacheRead := false + isBatch := false + + // Check main usage field + if result.Usage != nil { + usage = result.Usage + } else if result.Speech != nil && result.Speech.Usage != nil { + // For speech synthesis, create LLMUsage from AudioLLMUsage + usage = &schemas.LLMUsage{ + PromptTokens: result.Speech.Usage.InputTokens, + CompletionTokens: 0, // Speech doesn't have completion tokens + TotalTokens: result.Speech.Usage.TotalTokens, + } + + // Extract audio token details if available + if result.Speech.Usage.InputTokensDetails != nil { + audioTokenDetails = result.Speech.Usage.InputTokensDetails + } + } else if result.Transcribe != nil && result.Transcribe.Usage != nil && result.Transcribe.Usage.TotalTokens != nil { + // For transcription, create LLMUsage from TranscriptionUsage + inputTokens := 0 + outputTokens := 0 + if result.Transcribe.Usage.InputTokens != nil { + inputTokens = *result.Transcribe.Usage.InputTokens + } + if result.Transcribe.Usage.OutputTokens != nil { + outputTokens = *result.Transcribe.Usage.OutputTokens + } + usage = &schemas.LLMUsage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: int(*result.Transcribe.Usage.TotalTokens), + } + + // Extract audio duration if available (for duration-based pricing) + if result.Transcribe.Usage.Seconds != nil { + audioSeconds = result.Transcribe.Usage.Seconds + } + + // Extract audio token details if available + if result.Transcribe.Usage.InputTokenDetails != nil { + audioTokenDetails = result.Transcribe.Usage.InputTokenDetails + } + } + + cost := 0.0 + if usage != nil || audioSeconds != nil || audioTokenDetails != nil { + cost = pm.CalculateCostFromUsage(string(provider), model, usage, requestType, isCacheRead, isBatch, audioSeconds, audioTokenDetails) + } + + return cost +} + +func (pm *PricingManager) CalculateCostWithCacheDebug(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType) float64 { + if result == nil || provider == "" || model == "" || requestType == "" { + return 0.0 + } + cacheDebug := result.ExtraFields.CacheDebug + if cacheDebug != nil { + if cacheDebug.CacheHit { + if cacheDebug.HitType != nil && *cacheDebug.HitType == "direct" { + return 0 + } else if cacheDebug.ProviderUsed != nil && cacheDebug.ModelUsed != nil && cacheDebug.InputTokens != nil { + return pm.CalculateCostFromUsage(*cacheDebug.ProviderUsed, *cacheDebug.ModelUsed, &schemas.LLMUsage{ + PromptTokens: *cacheDebug.InputTokens, + CompletionTokens: 0, + TotalTokens: *cacheDebug.InputTokens, + }, schemas.EmbeddingRequest, false, false, nil, nil) + } + + // Don't over-bill cache hits if fields are missing. + return 0 + } else { + baseCost := pm.CalculateCost(result, provider, model, requestType) + var semanticCacheCost float64 + if cacheDebug.ProviderUsed != nil && cacheDebug.ModelUsed != nil && cacheDebug.InputTokens != nil { + semanticCacheCost = pm.CalculateCostFromUsage(*cacheDebug.ProviderUsed, *cacheDebug.ModelUsed, &schemas.LLMUsage{ + PromptTokens: *cacheDebug.InputTokens, + CompletionTokens: 0, + TotalTokens: *cacheDebug.InputTokens, + }, schemas.EmbeddingRequest, false, false, nil, nil) + } + + return baseCost + semanticCacheCost + } + } + + return pm.CalculateCost(result, provider, model, requestType) +} + +func (pm *PricingManager) Cleanup() error { + if pm.syncTicker != nil { + pm.syncTicker.Stop() + } + + close(pm.done) + pm.wg.Wait() + + return nil +} + +// CalculateCostFromUsage calculates cost in dollars using pricing manager and usage data with conditional pricing +func (pm *PricingManager) CalculateCostFromUsage(provider string, model string, usage *schemas.LLMUsage, requestType schemas.RequestType, isCacheRead bool, isBatch bool, audioSeconds *int, audioTokenDetails *schemas.AudioTokenDetails) float64 { + // Allow audio-only flows by only returning early if we have no usage data at all + if usage == nil && audioSeconds == nil && audioTokenDetails == nil { + return 0.0 + } + + // Get pricing for the model + pricing, exists := pm.getPricing(model, provider, requestType) + if !exists { + pm.logger.Warn("pricing not found for model %s and provider %s of request type %s, skipping cost calculation", model, provider, normalizeRequestType(requestType)) + return 0.0 + } + + var inputCost, outputCost float64 + + // Helper function to safely get token counts with zero defaults + safeTokenCount := func(usage *schemas.LLMUsage, getter func(*schemas.LLMUsage) int) int { + if usage == nil { + return 0 + } + return getter(usage) + } + + totalTokens := safeTokenCount(usage, func(u *schemas.LLMUsage) int { return u.TotalTokens }) + promptTokens := safeTokenCount(usage, func(u *schemas.LLMUsage) int { return u.PromptTokens }) + completionTokens := safeTokenCount(usage, func(u *schemas.LLMUsage) int { return u.CompletionTokens }) + + // Special handling for audio operations with duration-based pricing + if (requestType == schemas.SpeechRequest || requestType == schemas.TranscriptionRequest) && audioSeconds != nil && *audioSeconds > 0 { + // Determine if this is above TokenTierAbove128K for pricing tier selection + isAbove128k := totalTokens > TokenTierAbove128K + + // Use duration-based pricing for audio when available + var audioPerSecondRate *float64 + if isAbove128k && pricing.InputCostPerAudioPerSecondAbove128kTokens != nil { + audioPerSecondRate = pricing.InputCostPerAudioPerSecondAbove128kTokens + } else if pricing.InputCostPerAudioPerSecond != nil { + audioPerSecondRate = pricing.InputCostPerAudioPerSecond + } + + if audioPerSecondRate != nil { + inputCost = float64(*audioSeconds) * *audioPerSecondRate + } else { + // Fall back to token-based pricing + inputCost = float64(promptTokens) * pricing.InputCostPerToken + } + + // For audio operations, output cost is typically based on tokens (if any) + outputCost = float64(completionTokens) * pricing.OutputCostPerToken + + return inputCost + outputCost + } + + // Handle audio token details if available (for token-based audio pricing) + if audioTokenDetails != nil && (requestType == schemas.SpeechRequest || requestType == schemas.TranscriptionRequest) { + // Use audio-specific token pricing if available + audioTokens := float64(audioTokenDetails.AudioTokens) + textTokens := float64(audioTokenDetails.TextTokens) + isAbove128k := totalTokens > TokenTierAbove128K + + // Determine the appropriate token pricing rates + var inputTokenRate, outputTokenRate float64 + + if isAbove128k { + inputTokenRate = getSafeFloat64(pricing.InputCostPerTokenAbove128kTokens, pricing.InputCostPerToken) + outputTokenRate = getSafeFloat64(pricing.OutputCostPerTokenAbove128kTokens, pricing.OutputCostPerToken) + } else { + inputTokenRate = pricing.InputCostPerToken + outputTokenRate = pricing.OutputCostPerToken + } + + // Calculate costs using token-based pricing with audio/text breakdown + inputCost = audioTokens*inputTokenRate + textTokens*inputTokenRate + outputCost = float64(completionTokens) * outputTokenRate + + return inputCost + outputCost + } + + // Use conditional pricing based on request characteristics + if isBatch { + // Use batch pricing if available, otherwise fall back to regular pricing + if pricing.InputCostPerTokenBatches != nil { + inputCost = float64(promptTokens) * *pricing.InputCostPerTokenBatches + } else { + inputCost = float64(promptTokens) * pricing.InputCostPerToken + } + + if pricing.OutputCostPerTokenBatches != nil { + outputCost = float64(completionTokens) * *pricing.OutputCostPerTokenBatches + } else { + outputCost = float64(completionTokens) * pricing.OutputCostPerToken + } + } else if isCacheRead { + // Use cache read pricing for input tokens if available, regular pricing for output + if pricing.CacheReadInputTokenCost != nil { + inputCost = float64(promptTokens) * *pricing.CacheReadInputTokenCost + } else { + inputCost = float64(promptTokens) * pricing.InputCostPerToken + } + + // Output tokens always use regular pricing for cache reads + outputCost = float64(completionTokens) * pricing.OutputCostPerToken + } else { + // Use regular pricing + inputCost = float64(promptTokens) * pricing.InputCostPerToken + outputCost = float64(completionTokens) * pricing.OutputCostPerToken + } + + totalCost := inputCost + outputCost + + return totalCost +} + +// getPricing returns pricing information for a model (thread-safe) +func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() + + pricing, ok := pm.pricingData[makeKey(model, provider, normalizeRequestType(requestType))] + if !ok { + if provider == string(schemas.Gemini) { + pricing, ok = pm.pricingData[makeKey(model, "vertex", normalizeRequestType(requestType))] + if ok { + return &pricing, true + } + } + return nil, false + } + return &pricing, true +} diff --git a/framework/pricing/sync.go b/framework/pricing/sync.go new file mode 100644 index 000000000..607a4e90c --- /dev/null +++ b/framework/pricing/sync.go @@ -0,0 +1,239 @@ +package pricing + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/maximhq/bifrost/framework/configstore" + "gorm.io/gorm" +) + +// checkAndSyncPricing determines if pricing data needs to be synced and performs the sync if needed. +// It syncs pricing data in the following scenarios: +// - No config store available (returns early with no error) +// - No previous sync record exists +// - Previous sync timestamp is invalid/corrupted +// - Sync interval has elapsed since last successful sync +func (pm *PricingManager) checkAndSyncPricing() error { + // Skip sync if no config store is available + if pm.configStore == nil { + return nil + } + + // Determine if sync is needed and perform it + needsSync, reason := pm.shouldSyncPricing() + if needsSync { + pm.logger.Debug("pricing sync needed: %s", reason) + return pm.syncPricing() + } + + return nil +} + +// shouldSyncPricing determines if pricing data should be synced and returns the reason +func (pm *PricingManager) shouldSyncPricing() (bool, string) { + config, err := pm.configStore.GetConfig(LastPricingSyncKey) + if err != nil { + return true, "no previous sync record found" + } + + lastSync, err := time.Parse(time.RFC3339, config.Value) + if err != nil { + pm.logger.Warn("invalid last sync timestamp: %v", err) + return true, "corrupted sync timestamp" + } + + if time.Since(lastSync) >= DefaultPricingSyncInterval { + return true, "sync interval elapsed" + } + + return false, "sync not needed" +} + +// syncPricing syncs pricing data from URL to database and updates cache +func (pm *PricingManager) syncPricing() error { + pm.logger.Debug("Starting pricing data synchronization for governance") + + // Load pricing data from URL + pricingData, err := pm.loadPricingFromURL() + if err != nil { + // Check if we have existing data in database + pricingRecords, err := pm.configStore.GetModelPrices() + if err != nil { + return fmt.Errorf("failed to get pricing records: %w", err) + } + if len(pricingRecords) > 0 { + pm.logger.Error("failed to load pricing data from URL, but existing data found in database: %v", err) + return nil + } else { + return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err) + } + } + + // Update database in transaction + err = pm.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Clear existing pricing data + if err := pm.configStore.DeleteModelPrices(tx); err != nil { + return fmt.Errorf("failed to clear existing pricing data: %v", err) + } + + // Deduplicate and insert new pricing data + seen := make(map[string]bool) + for modelKey, entry := range pricingData { + pricing := convertPricingDataToTableModelPricing(modelKey, entry) + + // Create composite key for deduplication + key := makeKey(pricing.Model, pricing.Provider, pricing.Mode) + + // Skip if already seen + if exists, ok := seen[key]; ok && exists { + continue + } + + // Mark as seen + seen[key] = true + + if err := pm.configStore.CreateModelPrices(&pricing, tx); err != nil { + return fmt.Errorf("failed to create pricing record for model %s: %w", pricing.Model, err) + } + } + + // Clear seen map + seen = nil + + return nil + }) + + if err != nil { + return fmt.Errorf("failed to sync pricing data to database: %w", err) + } + + config := &configstore.TableConfig{ + Key: LastPricingSyncKey, + Value: time.Now().Format(time.RFC3339), + } + + // Update last sync time + if err := pm.configStore.UpdateConfig(config); err != nil { + pm.logger.Warn("Failed to update last sync time: %v", err) + } + + // Reload cache from database + if err := pm.loadPricingFromDatabase(); err != nil { + return fmt.Errorf("failed to reload pricing cache: %w", err) + } + + pm.logger.Info("successfully synced %d pricing records", len(pricingData)) + return nil +} + +// loadPricingFromURL loads pricing data from the remote URL +func (pm *PricingManager) loadPricingFromURL() (PricingData, error) { + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 30 * time.Second, + } + + // Make HTTP request + resp, err := client.Get(PricingFileURL) + if err != nil { + return nil, fmt.Errorf("failed to download pricing data: %w", err) + } + defer resp.Body.Close() + + // Check HTTP status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download pricing data: HTTP %d", resp.StatusCode) + } + + // Read response body + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read pricing data response: %w", err) + } + + // Unmarshal JSON data + var pricingData PricingData + if err := json.Unmarshal(data, &pricingData); err != nil { + return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err) + } + + pm.logger.Debug("successfully downloaded and parsed %d pricing records", len(pricingData)) + return pricingData, nil +} + +// loadPricingIntoMemory loads pricing data from URL into memory cache +func (pm *PricingManager) loadPricingIntoMemory() error { + pricingData, err := pm.loadPricingFromURL() + if err != nil { + return fmt.Errorf("failed to load pricing data from URL: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Clear and rebuild the pricing map + pm.pricingData = make(map[string]configstore.TableModelPricing, len(pricingData)) + for modelKey, entry := range pricingData { + pricing := convertPricingDataToTableModelPricing(modelKey, entry) + key := makeKey(pricing.Model, pricing.Provider, pricing.Mode) + pm.pricingData[key] = pricing + } + + return nil +} + +// loadPricingFromDatabase loads pricing data from database into memory cache +func (pm *PricingManager) loadPricingFromDatabase() error { + if pm.configStore == nil { + return nil + } + + pricingRecords, err := pm.configStore.GetModelPrices() + if err != nil { + return fmt.Errorf("failed to load pricing from database: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Clear and rebuild the pricing map + pm.pricingData = make(map[string]configstore.TableModelPricing, len(pricingRecords)) + for _, pricing := range pricingRecords { + key := makeKey(pricing.Model, pricing.Provider, pricing.Mode) + pm.pricingData[key] = pricing + } + + pm.logger.Debug("loaded %d pricing records into cache", len(pricingRecords)) + return nil +} + +// startSyncWorker starts the background sync worker +func (pm *PricingManager) startSyncWorker() { + // Use a ticker that checks every hour, but only sync when needed + pm.syncTicker = time.NewTicker(1 * time.Hour) + pm.wg.Add(1) + go pm.syncWorker() +} + +// syncWorker runs the background sync check +func (pm *PricingManager) syncWorker() { + defer pm.wg.Done() + defer pm.syncTicker.Stop() + + for { + select { + case <-pm.syncTicker.C: + // Check and sync pricing data - this handles the sync internally + if err := pm.checkAndSyncPricing(); err != nil { + pm.logger.Error("background pricing sync failed: %v", err) + } + + case <-pm.done: + return + } + } +} diff --git a/framework/pricing/utils.go b/framework/pricing/utils.go new file mode 100644 index 000000000..b841e3b20 --- /dev/null +++ b/framework/pricing/utils.go @@ -0,0 +1,126 @@ +package pricing + +import ( + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" +) + +// makeKey creates a unique key for a model, provider, and mode for pricingData map +func makeKey(model, provider, mode string) string { return model + "|" + provider + "|" + mode } + +// isBatchRequest checks if the request is for batch processing +func isBatchRequest(req *schemas.BifrostRequest) bool { + // Check for batch endpoints or batch-specific headers + // This could be detected via specific endpoint patterns or headers + // For now, return false + return false +} + +// isCacheReadRequest checks if the request involves cache reading +func isCacheReadRequest(req *schemas.BifrostRequest, headers map[string]string) bool { + // Check for cache-related headers or request parameters + if cacheHeader := headers["x-cache-read"]; cacheHeader == "true" { + return true + } + + // Check for anthropic cache headers + if cacheControl := headers["anthropic-beta"]; cacheControl != "" { + return true + } + + // TODO: Add message-level cache control detection when BifrostMessage schema supports it + // For now, cache detection relies on headers only + + return false +} + +// normalizeProvider normalizes the provider name to a consistent format +func normalizeProvider(p string) string { + if strings.Contains(p, "vertex_ai") || p == "google-vertex" { + return string(schemas.Vertex) + } else { + return p + } +} + +// normalizeRequestType normalizes the request type to a consistent format +func normalizeRequestType(reqType schemas.RequestType) string { + baseType := "unknown" + + switch reqType { + case schemas.TextCompletionRequest: + baseType = "completion" + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + baseType = "chat" + case schemas.EmbeddingRequest: + baseType = "embedding" + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + baseType = "audio_speech" + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + baseType = "audio_transcription" + } + + // TODO: Check for batch processing indicators + // if isBatchRequest(reqType) { + // return baseType + "_batch" + // } + + return baseType +} + +// convertPricingDataToTableModelPricing converts the pricing data to a TableModelPricing struct +func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstore.TableModelPricing { + provider := normalizeProvider(entry.Provider) + + // Handle provider/model format - extract just the model name + modelName := modelKey + if strings.Contains(modelKey, "/") { + parts := strings.Split(modelKey, "/") + if len(parts) > 1 { + modelName = strings.Join(parts[1:], "/") + } + } + + pricing := configstore.TableModelPricing{ + Model: modelName, + Provider: provider, + InputCostPerToken: entry.InputCostPerToken, + OutputCostPerToken: entry.OutputCostPerToken, + Mode: entry.Mode, + + // Additional pricing for media + InputCostPerImage: entry.InputCostPerImage, + InputCostPerVideoPerSecond: entry.InputCostPerVideoPerSecond, + InputCostPerAudioPerSecond: entry.InputCostPerAudioPerSecond, + + // Character-based pricing + InputCostPerCharacter: entry.InputCostPerCharacter, + OutputCostPerCharacter: entry.OutputCostPerCharacter, + + // Pricing above 128k tokens + InputCostPerTokenAbove128kTokens: entry.InputCostPerTokenAbove128kTokens, + InputCostPerCharacterAbove128kTokens: entry.InputCostPerCharacterAbove128kTokens, + InputCostPerImageAbove128kTokens: entry.InputCostPerImageAbove128kTokens, + InputCostPerVideoPerSecondAbove128kTokens: entry.InputCostPerVideoPerSecondAbove128kTokens, + InputCostPerAudioPerSecondAbove128kTokens: entry.InputCostPerAudioPerSecondAbove128kTokens, + OutputCostPerTokenAbove128kTokens: entry.OutputCostPerTokenAbove128kTokens, + OutputCostPerCharacterAbove128kTokens: entry.OutputCostPerCharacterAbove128kTokens, + + // Cache and batch pricing + CacheReadInputTokenCost: entry.CacheReadInputTokenCost, + InputCostPerTokenBatches: entry.InputCostPerTokenBatches, + OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches, + } + + return pricing +} + +// getSafeFloat64 returns the value of a float64 pointer or fallback if nil +func getSafeFloat64(ptr *float64, fallback float64) float64 { + if ptr != nil { + return *ptr + } + return fallback +} diff --git a/framework/vectorstore/errors.go b/framework/vectorstore/errors.go new file mode 100644 index 000000000..ffcd9cf41 --- /dev/null +++ b/framework/vectorstore/errors.go @@ -0,0 +1,8 @@ +package vectorstore + +import "errors" + +var ( + ErrNotFound = errors.New("vectorstore: not found") + ErrNotSupported = errors.New("vectorstore: operation not supported on this store") +) diff --git a/framework/vectorstore/redis.go b/framework/vectorstore/redis.go new file mode 100644 index 000000000..a5c07574c --- /dev/null +++ b/framework/vectorstore/redis.go @@ -0,0 +1,842 @@ +package vectorstore + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/redis/go-redis/v9" +) + +const ( + // defaultLimit is the default limit used for pagination and batch operations + BatchLimit = 100 +) + +type RedisConfig struct { + // Connection settings + Addr string `json:"addr"` // Redis server address (host:port) - REQUIRED + Username string `json:"username,omitempty"` // Username for Redis AUTH (optional) + Password string `json:"password,omitempty"` // Password for Redis AUTH (optional) + DB int `json:"db,omitempty"` // Redis database number (default: 0) + + // Connection pool and timeout settings (passed directly to Redis client) + PoolSize int `json:"pool_size,omitempty"` // Maximum number of socket connections (optional) + MaxActiveConns int `json:"max_active_conns,omitempty"` // Maximum number of active connections (optional) + MinIdleConns int `json:"min_idle_conns,omitempty"` // Minimum number of idle connections (optional) + MaxIdleConns int `json:"max_idle_conns,omitempty"` // Maximum number of idle connections (optional) + ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty"` // Connection maximum lifetime (optional) + ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty"` // Connection maximum idle time (optional) + DialTimeout time.Duration `json:"dial_timeout,omitempty"` // Timeout for socket connection (optional) + ReadTimeout time.Duration `json:"read_timeout,omitempty"` // Timeout for socket reads (optional) + WriteTimeout time.Duration `json:"write_timeout,omitempty"` // Timeout for socket writes (optional) + ContextTimeout time.Duration `json:"context_timeout,omitempty"` // Timeout for Redis operations (optional) +} + +// RedisStore represents the Redis vector store. +type RedisStore struct { + client *redis.Client + config RedisConfig + logger schemas.Logger +} + +func (s *RedisStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + // Check if index already exists + infoResult := s.client.Do(ctx, "FT.INFO", namespace) + if infoResult.Err() == nil { + return nil // Index already exists + } + if err := infoResult.Err(); err != nil && strings.Contains(strings.ToLower(err.Error()), "unknown command") { + return fmt.Errorf("RediSearch module not available: please use Redis Stack or enable RediSearch (FT.*). Original error: %w", err) + } + + // Extract metadata field names from properties + var metadataFields []string + for fieldName := range properties { + metadataFields = append(metadataFields, fieldName) + } + + // Create index with VECTOR field + metadata fields + keyPrefix := fmt.Sprintf("%s:", namespace) + + if dimension <= 0 { + return fmt.Errorf("redis vector index %q: dimension must be > 0 (got %d)", namespace, dimension) + } + + args := []interface{}{ + "FT.CREATE", namespace, + "ON", "HASH", + "PREFIX", "1", keyPrefix, + "SCHEMA", + // Native vector field with HNSW algorithm + "embedding", "VECTOR", "HNSW", "6", + "TYPE", "FLOAT32", + "DIM", dimension, + "DISTANCE_METRIC", "COSINE", + } + + // Add all metadata fields as TEXT with exact matching + // All values are converted to strings for consistent searching + for _, field := range metadataFields { + // Detect field type from VectorStoreProperties + prop := properties[field] + switch prop.DataType { + case VectorStorePropertyTypeInteger: + args = append(args, field, "NUMERIC") + default: + args = append(args, field, "TAG") + } + } + + // Create the index + if err := s.client.Do(ctx, args...).Err(); err != nil { + return fmt.Errorf("failed to create semantic vector index %s: %w", namespace, err) + } + + return nil +} + +func (s *RedisStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + if strings.TrimSpace(id) == "" { + return SearchResult{}, fmt.Errorf("id is required") + } + + // Create key with namespace + key := buildKey(namespace, id) + + // Get all fields from the hash + result := s.client.HGetAll(ctx, key) + if result.Err() != nil { + return SearchResult{}, fmt.Errorf("failed to get chunk: %w", result.Err()) + } + + fields := result.Val() + if len(fields) == 0 { + return SearchResult{}, fmt.Errorf("chunk not found: %s", id) + } + + // Build SearchResult + searchResult := SearchResult{ + ID: id, + Properties: make(map[string]interface{}), + } + + // Parse fields + for k, v := range fields { + searchResult.Properties[k] = v + } + + return searchResult, nil +} + +func (s *RedisStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + if len(ids) == 0 { + return []SearchResult{}, nil + } + + // Create keys with namespace + keys := make([]string, len(ids)) + for i, id := range ids { + if strings.TrimSpace(id) == "" { + return nil, fmt.Errorf("id cannot be empty at index %d", i) + } + keys[i] = buildKey(namespace, id) + } + + // Use pipeline for efficient batch retrieval + pipe := s.client.Pipeline() + cmds := make([]*redis.MapStringStringCmd, len(keys)) + + for i, key := range keys { + cmds[i] = pipe.HGetAll(ctx, key) + } + + // Execute pipeline + _, err := pipe.Exec(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute pipeline: %w", err) + } + + // Process results + var results []SearchResult + for i, cmd := range cmds { + if cmd.Err() != nil { + // Log error but continue with other results + s.logger.Debug(fmt.Sprintf("failed to get chunk %s: %v", ids[i], cmd.Err())) + continue + } + + fields := cmd.Val() + if len(fields) == 0 { + // Chunk not found, skip + continue + } + + // Build SearchResult + searchResult := SearchResult{ + ID: ids[i], + Properties: make(map[string]interface{}), + } + + // Parse fields + for k, v := range fields { + searchResult.Properties[k] = v + } + + results = append(results, searchResult) + } + + return results, nil +} + +func (s *RedisStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + // Set default limit if not provided + if limit < 0 { + limit = BatchLimit + } + + // Build Redis query from the provided queries + redisQuery := buildRedisQuery(queries) + + // Build FT.SEARCH command + args := []interface{}{ + "FT.SEARCH", namespace, + redisQuery, + } + + // Add RETURN only if specific fields were requested + if len(selectFields) > 0 { + args = append(args, "RETURN", len(selectFields)) + for _, field := range selectFields { + args = append(args, field) + } + } + + // Add LIMIT clause - use large limit for "all" (limit=0) + searchLimit := limit + if limit == 0 { + searchLimit = math.MaxInt32 // Use large limit to get all results + } + + // Add OFFSET for pagination if cursor is provided + offset := 0 + if cursor != nil && *cursor != "" { + if parsedOffset, err := strconv.ParseInt(*cursor, 10, 64); err == nil { + offset = int(parsedOffset) + } + } + + args = append(args, "LIMIT", offset, int(searchLimit), "DIALECT", "2") + + // Execute search + result := s.client.Do(ctx, args...) + if result.Err() != nil { + return nil, nil, fmt.Errorf("failed to search: %w", result.Err()) + } + + // Parse search results + results, err := s.parseSearchResults(result.Val(), selectFields) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse search results: %w", err) + } + + // Implement cursor-based pagination using OFFSET + var nextCursor *string = nil + if cursor != nil && *cursor != "" { + // If we have a cursor, we've already applied pagination + // Check if there might be more results + if len(results) == int(limit) && limit > 0 { + // There might be more results, create next cursor + offset, err := strconv.ParseInt(*cursor, 10, 64) + if err == nil { + nextOffset := offset + limit + nextCursorStr := strconv.FormatInt(nextOffset, 10) + nextCursor = &nextCursorStr + } + } + } else if len(results) == int(limit) && limit > 0 { + // First page and we got exactly the limit, there might be more + nextCursorStr := strconv.FormatInt(limit, 10) + nextCursor = &nextCursorStr + } + + return results, nextCursor, nil +} + +// parseSearchResults parses FT.SEARCH results into SearchResult slice +func (s *RedisStore) parseSearchResults(result interface{}, selectFields []string) ([]SearchResult, error) { + // FT.SEARCH returns a map with results array + resultMap, ok := result.(map[interface{}]interface{}) + if !ok { + return []SearchResult{}, nil + } + + resultsArray, ok := resultMap["results"].([]interface{}) + if !ok { + return []SearchResult{}, nil + } + + results := []SearchResult{} + + for _, resultItem := range resultsArray { + resultMap, ok := resultItem.(map[interface{}]interface{}) + if !ok { + continue + } + + // Get the document ID + id, ok := resultMap["id"].(string) + if !ok { + continue + } + + // Extract ID from key (remove namespace prefix) + keyParts := strings.Split(id, ":") + if len(keyParts) < 2 { + continue + } + documentID := strings.Join(keyParts[1:], ":") // Handle IDs that might contain colons + + // Get the extra_attributes (metadata) + extraAttributes, ok := resultMap["extra_attributes"].(map[interface{}]interface{}) + if !ok { + continue + } + + // Build SearchResult + searchResult := SearchResult{ + ID: documentID, + Properties: make(map[string]interface{}), + } + + // Parse extra_attributes + for fieldNameInterface, fieldValue := range extraAttributes { + fieldName, ok := fieldNameInterface.(string) + if !ok { + continue + } + + // Always include score field for vector searches + if fieldName == "score" { + searchResult.Properties[fieldName] = fieldValue + // Also set the Score field for proper access + if scoreFloat, ok := fieldValue.(float64); ok { + searchResult.Score = &scoreFloat + } + continue + } + + // Apply field selection if specified + if len(selectFields) > 0 { + // Check if this field should be included + include := false + for _, selectField := range selectFields { + if fieldName == selectField { + include = true + break + } + } + if !include { + continue + } + } + + searchResult.Properties[fieldName] = fieldValue + } + + results = append(results, searchResult) + } + + return results, nil +} + +// buildRedisQuery converts []Query to Redis query syntax +func buildRedisQuery(queries []Query) string { + if len(queries) == 0 { + return "*" + } + + var conditions []string + for _, query := range queries { + condition := buildRedisQueryCondition(query) + if condition != "" { + conditions = append(conditions, condition) + } + } + + if len(conditions) == 0 { + return "*" + } + + // Join conditions with space (AND operation in Redis) + return strings.Join(conditions, " ") +} + +// buildRedisQueryCondition builds a single Redis query condition +func buildRedisQueryCondition(query Query) string { + field := query.Field + operator := query.Operator + value := query.Value + + // Convert value to string + var stringValue string + switch val := value.(type) { + case string: + stringValue = val + case int, int64, float64, bool: + stringValue = fmt.Sprintf("%v", val) + default: + jsonData, _ := json.Marshal(val) + stringValue = string(jsonData) + } + + // Escape special characters for TAG fields + escapedValue := escapeSearchValue(stringValue) // new function for TAG escaping + + switch operator { + case QueryOperatorEqual: + // TAG exact match + return fmt.Sprintf("@%s:{%s}", field, escapedValue) + case QueryOperatorNotEqual: + // TAG negation + return fmt.Sprintf("-@%s:{%s}", field, escapedValue) + case QueryOperatorLike: + // Cannot do LIKE with TAGs directly; fallback to exact match + return fmt.Sprintf("@%s:{%s}", field, escapedValue) + case QueryOperatorGreaterThan: + return fmt.Sprintf("@%s:[(%s +inf]", field, escapedValue) + case QueryOperatorGreaterThanOrEqual: + return fmt.Sprintf("@%s:[%s +inf]", field, escapedValue) + case QueryOperatorLessThan: + return fmt.Sprintf("@%s:[-inf (%s]", field, escapedValue) + case QueryOperatorLessThanOrEqual: + return fmt.Sprintf("@%s:[-inf %s]", field, escapedValue) + case QueryOperatorIsNull: + // Field not present + return fmt.Sprintf("-@%s:*", field) + case QueryOperatorIsNotNull: + // Field exists + return fmt.Sprintf("@%s:*", field) + case QueryOperatorContainsAny: + if values, ok := value.([]interface{}); ok { + var orConditions []string + for _, v := range values { + vStr := fmt.Sprintf("%v", v) + orConditions = append(orConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr))) + } + return fmt.Sprintf("(%s)", strings.Join(orConditions, " | ")) + } + return fmt.Sprintf("@%s:{%s}", field, escapedValue) + case QueryOperatorContainsAll: + if values, ok := value.([]interface{}); ok { + var andConditions []string + for _, v := range values { + vStr := fmt.Sprintf("%v", v) + andConditions = append(andConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr))) + } + return strings.Join(andConditions, " ") + } + return fmt.Sprintf("@%s:{%s}", field, escapedValue) + default: + return fmt.Sprintf("@%s:{%s}", field, escapedValue) + } +} + +func (s *RedisStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + // Build Redis query from the provided queries + redisQuery := buildRedisQuery(queries) + + // Convert query embedding to binary format + queryBytes := float32SliceToBytes(vector) + + // Build hybrid FT.SEARCH query: metadata filters + KNN vector search + // The correct syntax is: (metadata_filter)=>[KNN k @embedding $vec AS score] + var hybridQuery string + if len(queries) > 0 { + // Wrap metadata query in parentheses for hybrid syntax + hybridQuery = fmt.Sprintf("(%s)", redisQuery) + } else { + // Wildcard for pure vector search + hybridQuery = "*" + } + + // Execute FT.SEARCH with KNN + // Use large limit for "all" (limit=0) in KNN query + knnLimit := limit + if limit == 0 { + knnLimit = math.MaxInt32 + } + + args := []interface{}{ + "FT.SEARCH", namespace, + fmt.Sprintf("%s=>[KNN %d @embedding $vec AS score]", hybridQuery, knnLimit), + "PARAMS", "2", "vec", queryBytes, + "SORTBY", "score", + } + + // Add RETURN clause - always include score for vector search + // For vector search, we need to include the score field generated by KNN + returnFields := []string{"score"} + if len(selectFields) > 0 { + returnFields = append(returnFields, selectFields...) + } + + args = append(args, "RETURN", len(returnFields)) + for _, field := range returnFields { + args = append(args, field) + } + + // Add LIMIT clause and DIALECT 2 for better query parsing + searchLimit := limit + if limit == 0 { + searchLimit = math.MaxInt32 + } + args = append(args, "LIMIT", 0, int(searchLimit), "DIALECT", "2") + + result := s.client.Do(ctx, args...) + if result.Err() != nil { + return nil, fmt.Errorf("native vector search failed: %w", result.Err()) + } + + // Parse search results + results, err := s.parseSearchResults(result.Val(), selectFields) + if err != nil { + return nil, err + } + + // Apply threshold filter and extract scores + var filteredResults []SearchResult + for _, result := range results { + // Extract score from the result + if scoreValue, exists := result.Properties["score"]; exists { + var score float64 + switch v := scoreValue.(type) { + case float64: + score = v + case float32: + score = float64(v) + case int: + score = float64(v) + case int64: + score = float64(v) + case string: + if parsedScore, err := strconv.ParseFloat(v, 64); err == nil { + score = parsedScore + } + } + + // Convert cosine distance to similarity: similarity = 1 - distance + similarity := 1.0 - score + result.Score = &similarity + + // Apply threshold filter + if similarity >= threshold { + filteredResults = append(filteredResults, result) + } + } else { + // If no score, include the result (shouldn't happen with KNN queries) + filteredResults = append(filteredResults, result) + } + } + + results = filteredResults + + return results, nil +} + +func (s *RedisStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + if strings.TrimSpace(id) == "" { + return fmt.Errorf("id is required") + } + + // Create key with namespace + key := buildKey(namespace, id) + + // Prepare hash fields: binary embedding + metadata + fields := make(map[string]interface{}) + + // Only add embedding if it's not empty + if len(embedding) > 0 { + // Convert float32 slice to bytes for Redis storage + embeddingBytes := float32SliceToBytes(embedding) + fields["embedding"] = embeddingBytes + } + + // Add metadata fields directly (no prefix needed with proper indexing) + for k, v := range metadata { + switch val := v.(type) { + case string: + fields[k] = val + case int, int64, float64, bool: + fields[k] = fmt.Sprintf("%v", val) + case []interface{}: + // Preserve arrays as JSON to support round-trips (e.g., stream_chunks) + b, err := json.Marshal(val) + if err != nil { + return fmt.Errorf("failed to marshal array metadata %s: %w", k, err) + } + fields[k] = string(b) + default: + // JSON encode complex types + jsonData, err := json.Marshal(val) + if err != nil { + return fmt.Errorf("failed to marshal metadata field %s: %w", k, err) + } + fields[k] = string(jsonData) + } + } + + // Store as hash for efficient native vector search + if err := s.client.HSet(ctx, key, fields).Err(); err != nil { + return fmt.Errorf("failed to store semantic cache entry: %w", err) + } + + return nil +} + +func (s *RedisStore) Delete(ctx context.Context, namespace string, id string) error { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + if strings.TrimSpace(id) == "" { + return fmt.Errorf("id is required") + } + + // Create key with namespace + key := buildKey(namespace, id) + + // Delete the hash key + result := s.client.Del(ctx, key) + if result.Err() != nil { + return fmt.Errorf("failed to delete chunk %s: %w", id, result.Err()) + } + + // Check if the key actually existed + if result.Val() == 0 { + return fmt.Errorf("chunk not found: %s", id) + } + + return nil +} + +func (s *RedisStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + // Use cursor-based deletion to handle large datasets efficiently + return s.deleteAllWithCursor(ctx, namespace, queries, nil) +} + +// deleteAllWithCursor performs cursor-based deletion for large datasets +func (s *RedisStore) deleteAllWithCursor(ctx context.Context, namespace string, queries []Query, cursor *string) ([]DeleteResult, error) { + // Get a batch of documents to delete (using pagination) + results, nextCursor, err := s.GetAll(ctx, namespace, queries, []string{}, cursor, BatchLimit) + if err != nil { + return nil, fmt.Errorf("failed to find documents to delete: %w", err) + } + + if len(results) == 0 { + return []DeleteResult{}, nil + } + + // Extract IDs from results + ids := make([]string, len(results)) + for i, result := range results { + ids[i] = result.ID + } + + // Delete this batch of documents + var deleteResults []DeleteResult + batchSize := BatchLimit // Process in batches to avoid overwhelming Redis + + for i := 0; i < len(ids); i += batchSize { + end := i + batchSize + if end > len(ids) { + end = len(ids) + } + batch := ids[i:end] + + // Create pipeline for batch deletion + pipe := s.client.Pipeline() + cmds := make([]*redis.IntCmd, len(batch)) + + for j, id := range batch { + key := buildKey(namespace, id) + cmds[j] = pipe.Del(ctx, key) + } + + // Execute pipeline + _, err := pipe.Exec(ctx) + if err != nil { + // If pipeline fails, mark all in this batch as failed + for _, id := range batch { + deleteResults = append(deleteResults, DeleteResult{ + ID: id, + Status: DeleteStatusError, + Error: fmt.Sprintf("pipeline execution failed: %v", err), + }) + } + continue + } + + // Process results for this batch + for j, cmd := range cmds { + id := batch[j] + if cmd.Err() != nil { + deleteResults = append(deleteResults, DeleteResult{ + ID: id, + Status: DeleteStatusError, + Error: cmd.Err().Error(), + }) + } else if cmd.Val() > 0 { + // Key existed and was deleted + deleteResults = append(deleteResults, DeleteResult{ + ID: id, + Status: DeleteStatusSuccess, + }) + } else { + // Key didn't exist + deleteResults = append(deleteResults, DeleteResult{ + ID: id, + Status: DeleteStatusError, + Error: "document not found", + }) + } + } + } + + // If there are more results, continue with next cursor + if nextCursor != nil { + nextResults, err := s.deleteAllWithCursor(ctx, namespace, queries, nextCursor) + if err != nil { + return nil, fmt.Errorf("failed to delete remaining documents: %w", err) + } + // Combine results from this batch and subsequent batches + deleteResults = append(deleteResults, nextResults...) + } + + return deleteResults, nil +} + +func (s *RedisStore) DeleteNamespace(ctx context.Context, namespace string) error { + ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) + defer cancel() + + // Drop the index using FT.DROPINDEX + if err := s.client.Do(ctx, "FT.DROPINDEX", namespace).Err(); err != nil { + // Check if error is "Unknown Index name" - that's OK, index doesn't exist + if strings.Contains(err.Error(), "Unknown Index name") { + return nil // Index doesn't exist, nothing to drop + } + return fmt.Errorf("failed to drop semantic index %s: %w", namespace, err) + } + + return nil +} + +func (s *RedisStore) Close(ctx context.Context, namespace string) error { + // Close the Redis client connection + return s.client.Close() +} + +// escapeSearchValue escapes special characters in search values. +func escapeSearchValue(value string) string { + // Escape special RediSearch characters + replacer := strings.NewReplacer( + "(", "\\(", + ")", "\\)", + "[", "\\[", + "]", "\\]", + "{", "\\{", + "}", "\\}", + "*", "\\*", + "?", "\\?", + "|", "\\|", + "&", "\\&", + "!", "\\!", + "@", "\\@", + "#", "\\#", + "$", "\\$", + "%", "\\%", + "^", "\\^", + "~", "\\~", + "`", "\\`", + "\"", "\\\"", + "'", "\\'", + " ", "\\ ", + "-", "\\-", + ",", "|", + ) + return replacer.Replace(value) +} + +// Binary embedding conversion helpers +func float32SliceToBytes(floats []float32) []byte { + bytes := make([]byte, len(floats)*4) + for i, f := range floats { + binary.LittleEndian.PutUint32(bytes[i*4:], math.Float32bits(f)) + } + return bytes +} + +// buildKey creates a Redis key by combining namespace and id. +func buildKey(namespace, id string) string { + return fmt.Sprintf("%s:%s", namespace, id) +} + +// newRedisStore creates a new Redis vector store. +func newRedisStore(ctx context.Context, config RedisConfig, logger schemas.Logger) (*RedisStore, error) { + // Validate required fields + if config.Addr == "" { + return nil, fmt.Errorf("redis addr is required") + } + + client := redis.NewClient(&redis.Options{ + Addr: config.Addr, + Username: config.Username, + Password: config.Password, + DB: config.DB, + Protocol: 3, // Explicitly use RESP3 protocol + PoolSize: config.PoolSize, + MaxActiveConns: config.MaxActiveConns, + MinIdleConns: config.MinIdleConns, + MaxIdleConns: config.MaxIdleConns, + ConnMaxLifetime: config.ConnMaxLifetime, + ConnMaxIdleTime: config.ConnMaxIdleTime, + DialTimeout: config.DialTimeout, + ReadTimeout: config.ReadTimeout, + WriteTimeout: config.WriteTimeout, + }) + + store := &RedisStore{ + client: client, + config: config, + logger: logger, + } + + return store, nil +} diff --git a/framework/vectorstore/redis_test.go b/framework/vectorstore/redis_test.go new file mode 100644 index 000000000..94052346f --- /dev/null +++ b/framework/vectorstore/redis_test.go @@ -0,0 +1,889 @@ +package vectorstore + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test constants +const ( + RedisTestTimeout = 30 * time.Second + TestNamespace = "TestRedis" + DefaultTestAddr = "localhost:6379" + DefaultRedisTestTimeout = 10 * time.Second + RedisTestDimension = 1536 +) + +// TestSetup provides common test infrastructure +type RedisTestSetup struct { + Store *RedisStore + Logger schemas.Logger + Config RedisConfig + ctx context.Context + cancel context.CancelFunc +} + +// NewRedisTestSetup creates a test setup with environment-driven configuration +func NewRedisTestSetup(t *testing.T) *RedisTestSetup { + // Get configuration from environment variables + addr := getEnvWithDefault("REDIS_ADDR", DefaultTestAddr) + username := os.Getenv("REDIS_USERNAME") + password := os.Getenv("REDIS_PASSWORD") + db, err := getEnvWithDefaultInt("REDIS_DB", 0) + if err != nil { + t.Fatalf("Failed to get REDIS_DB: %v", err) + } + + timeoutStr := getEnvWithDefault("REDIS_TIMEOUT", "10s") + timeout, err := time.ParseDuration(timeoutStr) + if err != nil { + timeout = DefaultRedisTestTimeout + } + + config := RedisConfig{ + Addr: addr, + Username: username, + Password: password, + DB: db, + ContextTimeout: timeout, + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + ctx, cancel := context.WithTimeout(context.Background(), RedisTestTimeout) + + store, err := newRedisStore(ctx, config, logger) + if err != nil { + cancel() + t.Fatalf("Failed to create Redis store: %v", err) + } + + setup := &RedisTestSetup{ + Store: store, + Logger: logger, + Config: config, + ctx: ctx, + cancel: cancel, + } + + // Ensure namespace exists for integration tests + if !testing.Short() { + setup.ensureNamespaceExists(t) + } + + return setup +} + +// Cleanup cleans up test resources +func (ts *RedisTestSetup) Cleanup(t *testing.T) { + defer ts.cancel() + + if !testing.Short() { + // Clean up test data + ts.cleanupTestData(t) + } + + if err := ts.Store.Close(ts.ctx, TestNamespace); err != nil { + t.Logf("Warning: Failed to close store: %v", err) + } +} + +// ensureNamespaceExists creates the test namespace in Redis +func (ts *RedisTestSetup) ensureNamespaceExists(t *testing.T) { + // Create namespace with test properties + properties := map[string]VectorStoreProperties{ + "key": { + DataType: VectorStorePropertyTypeString, + }, + "type": { + DataType: VectorStorePropertyTypeString, + }, + "test_type": { + DataType: VectorStorePropertyTypeString, + }, + "size": { + DataType: VectorStorePropertyTypeInteger, + }, + "public": { + DataType: VectorStorePropertyTypeBoolean, + }, + "author": { + DataType: VectorStorePropertyTypeString, + }, + "request_hash": { + DataType: VectorStorePropertyTypeString, + }, + "user": { + DataType: VectorStorePropertyTypeString, + }, + "lang": { + DataType: VectorStorePropertyTypeString, + }, + "category": { + DataType: VectorStorePropertyTypeString, + }, + "content": { + DataType: VectorStorePropertyTypeString, + }, + "response": { + DataType: VectorStorePropertyTypeString, + }, + "from_bifrost_semantic_cache_plugin": { + DataType: VectorStorePropertyTypeBoolean, + }, + } + + err := ts.Store.CreateNamespace(ts.ctx, TestNamespace, RedisTestDimension, properties) + if err != nil { + t.Fatalf("Failed to create namespace %q: %v", TestNamespace, err) + } + t.Logf("Created test namespace: %s", TestNamespace) +} + +// cleanupTestData removes all test objects from the namespace +func (ts *RedisTestSetup) cleanupTestData(t *testing.T) { + // Delete all objects in the test namespace + allTestKeys, _, err := ts.Store.GetAll(ts.ctx, TestNamespace, []Query{}, []string{}, nil, 1000) + if err != nil { + t.Logf("Warning: Failed to get all test keys: %v", err) + return + } + + for _, key := range allTestKeys { + err := ts.Store.Delete(ts.ctx, TestNamespace, key.ID) + if err != nil { + t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err) + } + } + + t.Logf("Cleaned up test namespace: %s", TestNamespace) +} + +// ============================================================================ +// UNIT TESTS +// ============================================================================ + +func TestRedisConfig_Validation(t *testing.T) { + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + ctx := context.Background() + + tests := []struct { + name string + config RedisConfig + expectError bool + errorMsg string + }{ + { + name: "valid config", + config: RedisConfig{ + Addr: "localhost:6379", + }, + expectError: false, + }, + { + name: "missing addr", + config: RedisConfig{ + Username: "user", + }, + expectError: true, + errorMsg: "redis addr is required", + }, + { + name: "with credentials", + config: RedisConfig{ + Addr: "localhost:6379", + Username: "default", + Password: "", + }, + expectError: false, + }, + { + name: "with custom db", + config: RedisConfig{ + Addr: "localhost:6379", + DB: 1, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, err := newRedisStore(ctx, tt.config, logger) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, store) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + // For valid configs, store creation should succeed + // (connection will fail later when actually using Redis) + assert.NoError(t, err) + assert.NotNil(t, store) + } + }) + } +} + +// ============================================================================ +// INTEGRATION TESTS (require real Redis instance with RediSearch) +// ============================================================================ + +func TestRedisStore_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + t.Run("Add and GetChunk", func(t *testing.T) { + testKey := generateUUID() + embedding := generateTestEmbedding(RedisTestDimension) + metadata := map[string]interface{}{ + "type": "document", + "size": 1024, + "public": true, + } + + // Add object + err := setup.Store.Add(setup.ctx, TestNamespace, testKey, embedding, metadata) + require.NoError(t, err) + + // Small delay to ensure consistency + time.Sleep(100 * time.Millisecond) + + // Get single chunk + result, err := setup.Store.GetChunk(setup.ctx, TestNamespace, testKey) + require.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, "document", result.Properties["type"]) // Should contain metadata + }) + + t.Run("Add without embedding", func(t *testing.T) { + testKey := generateUUID() + metadata := map[string]interface{}{ + "type": "metadata-only", + } + + // Add object without embedding + err := setup.Store.Add(setup.ctx, TestNamespace, testKey, nil, metadata) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Retrieve it + result, err := setup.Store.GetChunk(setup.ctx, TestNamespace, testKey) + require.NoError(t, err) + assert.Equal(t, "metadata-only", result.Properties["type"]) + }) + + t.Run("GetChunks batch retrieval", func(t *testing.T) { + // Add multiple objects + keys := []string{generateUUID(), generateUUID(), generateUUID()} + embeddings := [][]float32{ + generateTestEmbedding(RedisTestDimension), + generateTestEmbedding(RedisTestDimension), + nil, + } + metadata := []map[string]interface{}{ + {"type": "doc1", "size": 100}, + {"type": "doc2", "size": 200}, + {"type": "doc3", "size": 300}, + } + + for i, key := range keys { + emb := embeddings[i] + err := setup.Store.Add(setup.ctx, TestNamespace, key, emb, metadata[i]) + require.NoError(t, err) + } + + time.Sleep(100 * time.Millisecond) + + // Get all chunks + results, err := setup.Store.GetChunks(setup.ctx, TestNamespace, keys) + require.NoError(t, err) + assert.Len(t, results, 3) + + // Verify each result + for i, result := range results { + assert.Equal(t, keys[i], result.ID) + assert.Equal(t, metadata[i]["type"], result.Properties["type"]) + } + }) +} + +func TestRedisStore_FilteringScenarios(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + // Setup test data for filtering scenarios + testData := []struct { + key string + metadata map[string]interface{} + }{ + { + generateUUID(), + map[string]interface{}{ + "type": "pdf", + "size": 1024, + "public": true, + "author": "alice", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "docx", + "size": 2048, + "public": false, + "author": "bob", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "pdf", + "size": 512, + "public": true, + "author": "alice", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "txt", + "size": 256, + "public": true, + "author": "charlie", + }, + }, + } + + filterFields := []string{"type", "size", "public", "author"} + + // Add all test data + for _, item := range testData { + embedding := generateTestEmbedding(RedisTestDimension) + err := setup.Store.Add(setup.ctx, TestNamespace, item.key, embedding, item.metadata) + require.NoError(t, err) + } + + time.Sleep(500 * time.Millisecond) // Wait for consistency + + t.Run("Filter by numeric comparison", func(t *testing.T) { + queries := []Query{ + {Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 (1024) and doc2 (2048) + }) + + t.Run("Filter by boolean", func(t *testing.T) { + queries := []Query{ + {Field: "public", Operator: QueryOperatorEqual, Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 3) // doc1, doc3, doc4 + }) + + t.Run("Multiple filters (AND)", func(t *testing.T) { + queries := []Query{ + {Field: "type", Operator: QueryOperatorEqual, Value: "pdf"}, + {Field: "public", Operator: QueryOperatorEqual, Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 and doc3 + }) + + t.Run("Complex multi-condition filter", func(t *testing.T) { + queries := []Query{ + {Field: "author", Operator: QueryOperatorEqual, Value: "alice"}, + {Field: "size", Operator: QueryOperatorLessThan, Value: 2000}, + {Field: "public", Operator: QueryOperatorEqual, Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 and doc3 (both by alice, < 2000 size, public) + }) + + t.Run("Pagination test", func(t *testing.T) { + // Test with limit of 2 + results, cursor, err := setup.Store.GetAll(setup.ctx, TestNamespace, nil, filterFields, nil, 2) + require.NoError(t, err) + assert.Len(t, results, 2) + + if cursor != nil { + // Get next page + nextResults, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, nil, filterFields, cursor, 2) + require.NoError(t, err) + assert.LessOrEqual(t, len(nextResults), 2) + t.Logf("First page: %d results, Next page: %d results", len(results), len(nextResults)) + } + }) +} + +func TestRedisStore_VectorSearch(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + // Add test documents with embeddings + testDocs := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{ + "type": "tech", + "category": "programming", + "content": "Go programming language", + }, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{ + "type": "tech", + "category": "programming", + "content": "Python programming language", + }, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{ + "type": "sports", + "category": "football", + "content": "Football match results", + }, + }, + } + + for _, doc := range testDocs { + err := setup.Store.Add(setup.ctx, TestNamespace, doc.key, doc.embedding, doc.metadata) + require.NoError(t, err) + } + + time.Sleep(500 * time.Millisecond) + + t.Run("Vector similarity search", func(t *testing.T) { + // Search for similar content to the first document + queryEmbedding := testDocs[0].embedding + results, err := setup.Store.GetNearest(setup.ctx, TestNamespace, queryEmbedding, nil, []string{"type", "category", "content"}, 0.1, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(results), 1) + + // Check that results have scores and are not nil + require.NotEmpty(t, results) + require.NotNil(t, results[0].Score) + assert.InDelta(t, 1.0, *results[0].Score, 1e-4) + }) + + t.Run("Vector search with metadata filters", func(t *testing.T) { + // Search for tech content only + queries := []Query{ + {Field: "type", Operator: QueryOperatorEqual, Value: "tech"}, + } + + queryEmbedding := testDocs[0].embedding + results, err := setup.Store.GetNearest(setup.ctx, TestNamespace, queryEmbedding, queries, []string{"type", "category", "content"}, 0.1, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(results), 1) + + // All results should be tech type + for _, result := range results { + assert.Equal(t, "tech", result.Properties["type"]) + } + }) + + t.Run("Vector search with threshold", func(t *testing.T) { + // Use a very high threshold to get only very similar results + queryEmbedding := testDocs[0].embedding + results, err := setup.Store.GetNearest(setup.ctx, TestNamespace, queryEmbedding, nil, []string{"type", "category", "content"}, 0.99, 10) + require.NoError(t, err) + // Should return fewer results due to high threshold + t.Logf("High threshold search returned %d results", len(results)) + }) +} + +func TestRedisStore_CompleteUseCases(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + t.Run("Document Storage & Retrieval Scenario", func(t *testing.T) { + // Add documents with different types + documents := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "pdf", "size": 1024, "public": true}, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "docx", "size": 2048, "public": false}, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "pdf", "size": 512, "public": true}, + }, + } + + filterFields := []string{"type", "size", "public"} + + for _, doc := range documents { + err := setup.Store.Add(setup.ctx, TestNamespace, doc.key, doc.embedding, doc.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Test various retrieval patterns + + // Get PDF documents + pdfQuery := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "pdf"}} + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, pdfQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc3 + + // Get large documents (size > 1000) + sizeQuery := []Query{{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000}} + results, _, err = setup.Store.GetAll(setup.ctx, TestNamespace, sizeQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc2 + + // Get public PDFs + combinedQuery := []Query{ + {Field: "public", Operator: QueryOperatorEqual, Value: true}, + {Field: "type", Operator: QueryOperatorEqual, Value: "pdf"}, + } + results, _, err = setup.Store.GetAll(setup.ctx, TestNamespace, combinedQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc3 + + // Vector similarity search + queryEmbedding := documents[0].embedding // Similar to doc1 + vectorResults, err := setup.Store.GetNearest(setup.ctx, TestNamespace, queryEmbedding, nil, filterFields, 0.8, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(vectorResults), 1) + }) + + t.Run("Semantic Cache-like Workflow", func(t *testing.T) { + // Add request-response pairs with parameters + cacheEntries := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{ + "request_hash": "abc123", + "user": "u1", + "lang": "en", + "response": "answer1", + "from_bifrost_semantic_cache_plugin": true, + }, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{ + "request_hash": "def456", + "user": "u1", + "lang": "es", + "response": "answer2", + "from_bifrost_semantic_cache_plugin": true, + }, + }, + } + + filterFields := []string{"request_hash", "user", "lang", "response", "from_bifrost_semantic_cache_plugin"} + + for _, entry := range cacheEntries { + err := setup.Store.Add(setup.ctx, TestNamespace, entry.key, entry.embedding, entry.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Test hash-based direct retrieval (exact match) + hashQuery := []Query{{Field: "request_hash", Operator: QueryOperatorEqual, Value: "abc123"}} + results, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, hashQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 1) + + // Test semantic search with user and language filters + userLangFilter := []Query{ + {Field: "user", Operator: QueryOperatorEqual, Value: "u1"}, + {Field: "lang", Operator: QueryOperatorEqual, Value: "en"}, + } + similarEmbedding := generateSimilarEmbedding(cacheEntries[0].embedding, 0.9) + vectorResults, err := setup.Store.GetNearest(setup.ctx, TestNamespace, similarEmbedding, userLangFilter, filterFields, 0.7, 10) + require.NoError(t, err) + assert.Len(t, vectorResults, 1) // Should find English content for u1 + }) +} + +func TestRedisStore_DeleteOperations(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + t.Run("Delete single item", func(t *testing.T) { + // Add an item + key := generateUUID() + embedding := generateTestEmbedding(RedisTestDimension) + metadata := map[string]interface{}{"type": "test", "value": "delete_me"} + + err := setup.Store.Add(setup.ctx, TestNamespace, key, embedding, metadata) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Verify it exists + result, err := setup.Store.GetChunk(setup.ctx, TestNamespace, key) + require.NoError(t, err) + assert.Equal(t, "test", result.Properties["type"]) + + // Delete it + err = setup.Store.Delete(setup.ctx, TestNamespace, key) + require.NoError(t, err) + + // Verify it's gone + _, err = setup.Store.GetChunk(setup.ctx, TestNamespace, key) + assert.Error(t, err) + }) + + t.Run("DeleteAll with filters", func(t *testing.T) { + // Add multiple items with different types + testItems := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "delete_me", "category": "test"}, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "delete_me", "category": "test"}, + }, + { + generateUUID(), + generateTestEmbedding(RedisTestDimension), + map[string]interface{}{"type": "keep_me", "category": "test"}, + }, + } + + for _, item := range testItems { + err := setup.Store.Add(setup.ctx, TestNamespace, item.key, item.embedding, item.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Delete all items with type "delete_me" + queries := []Query{ + {Field: "type", Operator: QueryOperatorEqual, Value: "delete_me"}, + } + + deleteResults, err := setup.Store.DeleteAll(setup.ctx, TestNamespace, queries) + require.NoError(t, err) + assert.Len(t, deleteResults, 2) // Should delete 2 items + + // Verify only "keep_me" items remain + allResults, _, err := setup.Store.GetAll(setup.ctx, TestNamespace, nil, []string{"type"}, nil, 10) + require.NoError(t, err) + assert.Len(t, allResults, 1) // Only the "keep_me" item should remain + assert.Equal(t, "keep_me", allResults[0].Properties["type"]) + }) +} + +// ============================================================================ +// INTERFACE COMPLIANCE TESTS +// ============================================================================ + +func TestRedisStore_InterfaceCompliance(t *testing.T) { + // Verify that RedisStore implements VectorStore interface + var _ VectorStore = (*RedisStore)(nil) +} + +func TestVectorStoreFactory_Redis(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + config := &Config{ + Enabled: true, + Type: VectorStoreTypeRedis, + Config: RedisConfig{ + Addr: getEnvWithDefault("REDIS_ADDR", DefaultTestAddr), + Username: os.Getenv("REDIS_USERNAME"), + Password: os.Getenv("REDIS_PASSWORD"), + }, + } + + store, err := NewVectorStore(context.Background(), config, logger) + if err != nil { + t.Skipf("Could not create Redis store: %v", err) + } + defer store.Close(context.Background(), TestNamespace) + + // Verify it's actually a RedisStore + redisStore, ok := store.(*RedisStore) + assert.True(t, ok) + assert.NotNil(t, redisStore) +} + +// ============================================================================ +// ERROR HANDLING TESTS +// ============================================================================ + +func TestRedisStore_ErrorHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + t.Run("GetChunk with non-existent key", func(t *testing.T) { + _, err := setup.Store.GetChunk(setup.ctx, TestNamespace, "non-existent-key") + assert.Error(t, err) + }) + + t.Run("Delete non-existent key", func(t *testing.T) { + err := setup.Store.Delete(setup.ctx, TestNamespace, "non-existent-key") + assert.Error(t, err) + }) + + t.Run("Add with empty ID", func(t *testing.T) { + embedding := generateTestEmbedding(RedisTestDimension) + metadata := map[string]interface{}{"type": "test"} + + err := setup.Store.Add(setup.ctx, TestNamespace, "", embedding, metadata) + assert.Error(t, err) + }) + + t.Run("GetNearest with empty namespace", func(t *testing.T) { + embedding := generateTestEmbedding(RedisTestDimension) + _, err := setup.Store.GetNearest(setup.ctx, "", embedding, nil, []string{}, 0.8, 10) + assert.Error(t, err) + }) +} + +func TestRedisStore_NamespaceDimensionHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewRedisTestSetup(t) + defer setup.Cleanup(t) + + testNamespace := "TestDimensionHandling" + + t.Run("Recreate namespace with different dimension should not crash", func(t *testing.T) { + properties := map[string]VectorStoreProperties{ + "type": {DataType: VectorStorePropertyTypeString}, + "test": {DataType: VectorStorePropertyTypeString}, + } + + // Step 1: Create namespace with dimension 512 + err := setup.Store.CreateNamespace(setup.ctx, testNamespace, 512, properties) + require.NoError(t, err) + + // Add a document with 512-dimensional embedding + embedding512 := generateTestEmbedding(512) + metadata := map[string]interface{}{ + "type": "test_doc", + "test": "dimension_512", + } + + err = setup.Store.Add(setup.ctx, testNamespace, "test-key-512", embedding512, metadata) + require.NoError(t, err) + + // Verify it was added + result, err := setup.Store.GetChunk(setup.ctx, testNamespace, "test-key-512") + require.NoError(t, err) + assert.Equal(t, "dimension_512", result.Properties["test"]) + + // Step 2: Delete the namespace + err = setup.Store.DeleteNamespace(setup.ctx, testNamespace) + require.NoError(t, err) + + // Step 3: Create namespace with same name but different dimension - should not crash + err = setup.Store.CreateNamespace(setup.ctx, testNamespace, 1024, properties) + require.NoError(t, err) + + // Add a document with 1024-dimensional embedding + embedding1024 := generateTestEmbedding(1024) + metadata1024 := map[string]interface{}{ + "type": "test_doc", + "test": "dimension_1024", + } + + err = setup.Store.Add(setup.ctx, testNamespace, "test-key-1024", embedding1024, metadata1024) + require.NoError(t, err) + + // Verify new document exists + result, err = setup.Store.GetChunk(setup.ctx, testNamespace, "test-key-1024") + require.NoError(t, err) + assert.Equal(t, "dimension_1024", result.Properties["test"]) + + // Verify vector search works with new dimension + vectorResults, err := setup.Store.GetNearest(setup.ctx, testNamespace, embedding1024, nil, []string{"type", "test"}, 0.8, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(vectorResults), 1) + assert.NotNil(t, vectorResults[0].Score) + + // Cleanup + err = setup.Store.DeleteNamespace(setup.ctx, testNamespace) + if err != nil { + t.Logf("Warning: Failed to cleanup namespace: %v", err) + } + }) +} diff --git a/framework/vectorstore/store.go b/framework/vectorstore/store.go new file mode 100644 index 000000000..309c4cea5 --- /dev/null +++ b/framework/vectorstore/store.go @@ -0,0 +1,168 @@ +// Package vectorstore provides a generic interface for vector stores. +package vectorstore + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/maximhq/bifrost/core/schemas" +) + +type VectorStoreType string + +const ( + VectorStoreTypeWeaviate VectorStoreType = "weaviate" + VectorStoreTypeRedis VectorStoreType = "redis" +) + +// Query represents a query to the vector store. +type Query struct { + Field string + Operator QueryOperator + Value interface{} +} + +type QueryOperator string + +const ( + QueryOperatorEqual QueryOperator = "Equal" + QueryOperatorNotEqual QueryOperator = "NotEqual" + QueryOperatorGreaterThan QueryOperator = "GreaterThan" + QueryOperatorLessThan QueryOperator = "LessThan" + QueryOperatorGreaterThanOrEqual QueryOperator = "GreaterThanOrEqual" + QueryOperatorLessThanOrEqual QueryOperator = "LessThanOrEqual" + QueryOperatorLike QueryOperator = "Like" + QueryOperatorContainsAny QueryOperator = "ContainsAny" + QueryOperatorContainsAll QueryOperator = "ContainsAll" + QueryOperatorIsNull QueryOperator = "IsNull" + QueryOperatorIsNotNull QueryOperator = "IsNotNull" +) + +// SearchResult represents a search result with metadata. +type SearchResult struct { + ID string + Score *float64 + Properties map[string]interface{} +} + +// DeleteResult represents the result of a delete operation. +type DeleteResult struct { + ID string + Status DeleteStatus + Error string +} + +type DeleteStatus string + +const ( + DeleteStatusSuccess DeleteStatus = "success" + DeleteStatusError DeleteStatus = "error" +) + +type VectorStoreProperties struct { + DataType VectorStorePropertyType `json:"data_type"` + Description string `json:"description"` +} + +type VectorStorePropertyType string + +const ( + VectorStorePropertyTypeString VectorStorePropertyType = "string" + VectorStorePropertyTypeInteger VectorStorePropertyType = "integer" + VectorStorePropertyTypeBoolean VectorStorePropertyType = "boolean" + VectorStorePropertyTypeStringArray VectorStorePropertyType = "string[]" +) + +// VectorStore represents the interface for the vector store. +type VectorStore interface { + CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error + DeleteNamespace(ctx context.Context, namespace string) error + GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) + GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) + GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) + GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) + Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error + Delete(ctx context.Context, namespace string, id string) error + DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) + Close(ctx context.Context, namespace string) error +} + +// Config represents the configuration for the vector store. +type Config struct { + Enabled bool `json:"enabled"` + Type VectorStoreType `json:"type"` + Config any `json:"config"` +} + +// UnmarshalJSON unmarshals the config from JSON. +func (c *Config) UnmarshalJSON(data []byte) error { + // First, unmarshal into a temporary struct to get the basic fields + type TempConfig struct { + Enabled bool `json:"enabled"` + Type string `json:"type"` + Config json.RawMessage `json:"config"` // Keep as raw JSON + } + + var temp TempConfig + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Set basic fields + c.Enabled = temp.Enabled + c.Type = VectorStoreType(temp.Type) + + // Parse the config field based on type + switch c.Type { + case VectorStoreTypeWeaviate: + var weaviateConfig WeaviateConfig + if err := json.Unmarshal(temp.Config, &weaviateConfig); err != nil { + return fmt.Errorf("failed to unmarshal weaviate config: %w", err) + } + c.Config = weaviateConfig + case VectorStoreTypeRedis: + var redisConfig RedisConfig + if err := json.Unmarshal(temp.Config, &redisConfig); err != nil { + return fmt.Errorf("failed to unmarshal redis config: %w", err) + } + c.Config = redisConfig + default: + return fmt.Errorf("unknown vector store type: %s", temp.Type) + } + + return nil +} + +// NewVectorStore returns a new vector store based on the configuration. +func NewVectorStore(ctx context.Context, config *Config, logger schemas.Logger) (VectorStore, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + if !config.Enabled { + return nil, fmt.Errorf("vector store is disabled") + } + + switch config.Type { + case VectorStoreTypeWeaviate: + if config.Config == nil { + return nil, fmt.Errorf("weaviate config is required") + } + weaviateConfig, ok := config.Config.(WeaviateConfig) + if !ok { + return nil, fmt.Errorf("invalid weaviate config") + } + return newWeaviateStore(ctx, &weaviateConfig, logger) + case VectorStoreTypeRedis: + if config.Config == nil { + return nil, fmt.Errorf("redis config is required") + } + redisConfig, ok := config.Config.(RedisConfig) + if !ok { + return nil, fmt.Errorf("invalid redis config") + } + return newRedisStore(ctx, redisConfig, logger) + } + return nil, fmt.Errorf("invalid vector store type: %s", config.Type) +} diff --git a/framework/vectorstore/test_utils.go b/framework/vectorstore/test_utils.go new file mode 100644 index 000000000..54eaf9450 --- /dev/null +++ b/framework/vectorstore/test_utils.go @@ -0,0 +1,47 @@ +package vectorstore + +import ( + "math/rand" + "os" + "strconv" + + "github.com/google/uuid" +) + +// Helper functions +func getEnvWithDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getEnvWithDefaultInt(key string, defaultValue int) (int, error) { + if value := os.Getenv(key); value != "" { + return strconv.Atoi(value) + } + return defaultValue, nil +} + +func generateUUID() string { + return uuid.New().String() +} + +func generateTestEmbedding(dim int) []float32 { + embedding := make([]float32, dim) + for i := range embedding { + embedding[i] = rand.Float32()*2 - 1 // Random values between -1 and 1 + } + return embedding +} + +func generateSimilarEmbedding(original []float32, similarity float32) []float32 { + similar := make([]float32, len(original)) + for i := range similar { + // Add small random noise to create similar but not identical embedding + noise := (rand.Float32()*2 - 1) * (1 - similarity) * 0.1 + similar[i] = original[i] + noise + } + return similar +} + diff --git a/framework/vectorstore/utils.go b/framework/vectorstore/utils.go new file mode 100644 index 000000000..82c8ddace --- /dev/null +++ b/framework/vectorstore/utils.go @@ -0,0 +1,15 @@ +package vectorstore + +import ( + "context" + "time" +) + +// withTimeout adds a timeout to the context if it is set. +func withTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout > 0 { + return context.WithTimeout(ctx, timeout) + } + // No-op cancel to simplify call sites. + return ctx, func() {} +} diff --git a/framework/vectorstore/weaviate.go b/framework/vectorstore/weaviate.go new file mode 100644 index 000000000..34fbc3c0b --- /dev/null +++ b/framework/vectorstore/weaviate.go @@ -0,0 +1,612 @@ +package vectorstore + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate-go-client/v5/weaviate/auth" + "github.com/weaviate/weaviate-go-client/v5/weaviate/filters" + "github.com/weaviate/weaviate-go-client/v5/weaviate/graphql" + "github.com/weaviate/weaviate-go-client/v5/weaviate/grpc" + "github.com/weaviate/weaviate/entities/models" +) + +// Default values for Weaviate vector index configuration +const ( + // Default class names (Weaviate prefers PascalCase) + DefaultClassName = "BifrostStore" +) + +// WeaviateConfig represents the configuration for the Weaviate vector store. +type WeaviateConfig struct { + // Connection settings + Scheme string `json:"scheme"` // "http" or "https" - REQUIRED + Host string `json:"host"` // "localhost:8080" - REQUIRED + GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // grpc config for weaviate (optional) + + // Authentication settings (optional) + ApiKey string `json:"api_key,omitempty"` // API key for authentication + Headers map[string]string `json:"headers,omitempty"` // Additional headers + + // Connection settings + Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional) +} + +type WeaviateGrpcConfig struct { + // Host is the host of the weaviate server (host:port). + // If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used. + Host string `json:"host"` + // Secured is a boolean flag indicating if the connection is secured + Secured bool `json:"secured"` +} + +// WeaviateStore represents the Weaviate vector store. +type WeaviateStore struct { + client *weaviate.Client + config *WeaviateConfig + logger schemas.Logger +} + +// Add stores a new object (with or without embedding) +func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error { + if strings.TrimSpace(id) == "" { + return fmt.Errorf("id is required") + } + + obj := &models.Object{ + Class: className, + Properties: metadata, + } + + var err error + if len(embedding) > 0 { + _, err = s.client.Data().Creator(). + WithClassName(className). + WithID(id). + WithProperties(obj.Properties). + WithVector(embedding). + Do(ctx) + } else { + _, err = s.client.Data().Creator(). + WithClassName(className). + WithID(id). + WithProperties(obj.Properties). + Do(ctx) + } + + return err +} + +// GetChunk returns the "metadata" for a single key +func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) { + obj, err := s.client.Data().ObjectsGetter(). + WithClassName(className). + WithID(id). + Do(ctx) + if err != nil { + return SearchResult{}, err + } + if len(obj) == 0 { + return SearchResult{}, fmt.Errorf("not found: %s", id) + } + + props, ok := obj[0].Properties.(map[string]interface{}) + if !ok { + return SearchResult{}, fmt.Errorf("invalid properties") + } + + return SearchResult{ + ID: id, + Score: nil, + Properties: props, + }, nil +} + +// GetChunks returns multiple objects by ID +func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) { + out := make([]SearchResult, 0, len(ids)) + for _, id := range ids { + obj, err := s.client.Data().ObjectsGetter(). + WithClassName(className). + WithID(id). + Do(ctx) + if err != nil { + return nil, err + } + if len(obj) > 0 { + props, ok := obj[0].Properties.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid properties") + } + out = append(out, SearchResult{ + ID: id, + Score: nil, + Properties: props, + }) + } + } + return out, nil +} + +// GetAll with filtering + pagination +func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) { + where := buildWeaviateFilter(queries) + + fields := []graphql.Field{ + {Name: "_additional", Fields: []graphql.Field{ + {Name: "id"}, + }}, + } + for _, field := range selectFields { + fields = append(fields, graphql.Field{Name: field}) + } + + search := s.client.GraphQL().Get(). + WithClassName(className). + WithLimit(int(limit)). + WithFields(fields...) + + if where != nil { + search = search.WithWhere(where) + } + if cursor != nil { + search = search.WithAfter(*cursor) + } + + resp, err := search.Do(ctx) + if err != nil { + return nil, nil, err + } + + // Check for GraphQL errors + if len(resp.Errors) > 0 { + var errorMsgs []string + for _, err := range resp.Errors { + errorMsgs = append(errorMsgs, err.Message) + } + return nil, nil, fmt.Errorf("graphql errors: %v", errorMsgs) + } + + data, ok := resp.Data["Get"].(map[string]interface{}) + if !ok { + return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data) + } + + objsRaw, exists := data[className] + if !exists { + // No results for this class - this is normal, not an error + s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data)) + return nil, nil, nil + } + + objs, ok := objsRaw.([]interface{}) + if !ok { + s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw)) + return nil, nil, nil + } + + results := make([]SearchResult, 0, len(objs)) + var nextCursor *string + for _, o := range objs { + obj, ok := o.(map[string]interface{}) + if !ok { + continue + } + + // Convert to SearchResult format for consistency + searchResult := SearchResult{ + Properties: obj, + } + + if additional, ok := obj["_additional"].(map[string]interface{}); ok { + if id, ok := additional["id"].(string); ok { + searchResult.ID = id + nextCursor = &id + } + } + + results = append(results, searchResult) + } + + return results, nextCursor, nil +} + +// GetNearest with explicit filters only +func (s *WeaviateStore) GetNearest( + ctx context.Context, + className string, + vector []float32, + queries []Query, + selectFields []string, + threshold float64, + limit int64, +) ([]SearchResult, error) { + where := buildWeaviateFilter(queries) + + fields := []graphql.Field{ + {Name: "_additional", Fields: []graphql.Field{ + {Name: "id"}, + {Name: "certainty"}, + }}, + } + + for _, field := range selectFields { + fields = append(fields, graphql.Field{Name: field}) + } + + nearVector := s.client.GraphQL().NearVectorArgBuilder(). + WithVector(vector). + WithCertainty(float32(threshold)) + + search := s.client.GraphQL().Get(). + WithClassName(className). + WithNearVector(nearVector). + WithLimit(int(limit)). + WithFields(fields...) + + if where != nil { + search = search.WithWhere(where) + } + + resp, err := search.Do(ctx) + if err != nil { + return nil, err + } + + // Check for GraphQL errors + if len(resp.Errors) > 0 { + var errorMsgs []string + for _, err := range resp.Errors { + errorMsgs = append(errorMsgs, err.Message) + } + return nil, fmt.Errorf("graphql errors: %v", errorMsgs) + } + + data, ok := resp.Data["Get"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data) + } + + objsRaw, exists := data[className] + if !exists { + // No results for this class - this is normal, not an error + s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data)) + return nil, nil + } + + objs, ok := objsRaw.([]interface{}) + if !ok { + s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw)) + return nil, nil + } + + results := make([]SearchResult, 0, len(objs)) + for _, o := range objs { + obj, ok := o.(map[string]interface{}) + if !ok { + continue + } + + additional, ok := obj["_additional"].(map[string]interface{}) + if !ok { + continue + } + + // Safely extract ID + idRaw, exists := additional["id"] + if !exists || idRaw == nil { + continue + } + id, ok := idRaw.(string) + if !ok { + continue + } + + // Safely extract certainty/score with default value + var score float64 + if certaintyRaw, exists := additional["certainty"]; exists && certaintyRaw != nil { + switch v := certaintyRaw.(type) { + case float64: + score = v + case float32: + score = float64(v) + case int: + score = float64(v) + case int64: + score = float64(v) + default: + score = 0.0 // Default score if type conversion fails + } + } + + results = append(results, SearchResult{ + ID: id, + Score: &score, + Properties: obj, + }) + } + + return results, nil +} + +// Delete removes multiple objects by ID +func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error { + return s.client.Data().Deleter(). + WithClassName(className). + WithID(id). + Do(ctx) +} + +func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) { + where := buildWeaviateFilter(queries) + + res, err := s.client.Batch().ObjectsBatchDeleter(). + WithClassName(className). + WithWhere(where). + Do(ctx) + if err != nil { + return nil, err + } + + // NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes. + results := make([]DeleteResult, 0, len(res.Results.Objects)) + + for _, obj := range res.Results.Objects { + result := DeleteResult{ + ID: obj.ID.String(), + } + + if obj.Status != nil { + switch *obj.Status { + case "SUCCESS": + result.Status = DeleteStatusSuccess + case "FAILED": + result.Status = DeleteStatusError + + if obj.Errors != nil { + var errorMsgs []string + for _, err := range obj.Errors.Error { + errorMsgs = append(errorMsgs, err.Message) + } + + result.Error = strings.Join(errorMsgs, ", ") + } + } + } + + results = append(results, result) + } + + return results, nil +} + +func (s *WeaviateStore) Close(ctx context.Context, className string) error { + // nothing to close + return nil +} + +// newWeaviateStore creates a new Weaviate vector store. +func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) { + // Validate required config + if config.Scheme == "" || config.Host == "" { + return nil, fmt.Errorf("weaviate scheme and host are required") + } + + // Build client configuration + cfg := weaviate.Config{ + Scheme: config.Scheme, + Host: config.Host, + } + + // Add authentication if provided + if config.ApiKey != "" { + cfg.AuthConfig = auth.ApiKey{Value: config.ApiKey} + } + + // Add grpc config if provided + if config.GrpcConfig != nil { + cfg.GrpcConfig = &grpc.Config{ + Host: config.GrpcConfig.Host, + Secured: config.GrpcConfig.Secured, + } + } + + // Add custom headers if provided + if len(config.Headers) > 0 { + cfg.Headers = config.Headers + } + + // Create client + client, err := weaviate.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create weaviate client: %w", err) + } + + // Test connection with meta endpoint + testCtx := ctx + if config.Timeout > 0 { + var cancel context.CancelFunc + testCtx, cancel = context.WithTimeout(ctx, config.Timeout) + defer cancel() + } + + _, err = client.Misc().MetaGetter().Do(testCtx) + if err != nil { + return nil, fmt.Errorf("failed to connect to weaviate: %w", err) + } + + store := &WeaviateStore{ + client: client, + config: config, + logger: logger, + } + + return store, nil +} + +func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error { + // Check if class exists + exists, err := s.client.Schema().ClassExistenceChecker(). + WithClassName(className). + Do(ctx) + if err != nil { + return fmt.Errorf("failed to check class existence: %w", err) + } + + if exists { + return nil // Schema already exists + } + + // Create properties + weaviateProperties := []*models.Property{} + for name, prop := range properties { + var dataType []string + switch prop.DataType { + case VectorStorePropertyTypeString: + dataType = []string{"string"} + case VectorStorePropertyTypeInteger: + dataType = []string{"int"} + case VectorStorePropertyTypeBoolean: + dataType = []string{"boolean"} + case VectorStorePropertyTypeStringArray: + dataType = []string{"string[]"} + } + + weaviateProperties = append(weaviateProperties, &models.Property{ + Name: name, + DataType: dataType, + Description: prop.Description, + }) + } + + // Create class schema with all fields we need + classSchema := &models.Class{ + Class: className, + Properties: weaviateProperties, + VectorIndexType: "hnsw", + Vectorizer: "none", // We provide our own vectors + } + + if dimension > 0 { + classSchema.VectorIndexConfig = map[string]interface{}{ + "vectorDimensions": dimension, + } + } + + err = s.client.Schema().ClassCreator(). + WithClass(classSchema). + Do(ctx) + if err != nil { + return fmt.Errorf("failed to create class schema: %w", err) + } + + return nil +} + +func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error { + exists, err := s.client.Schema().ClassExistenceChecker(). + WithClassName(className). + Do(ctx) + if err != nil { + return fmt.Errorf("failed to check class existence: %w", err) + } + if !exists { + return nil // Schema already does not exist + } else { + return s.client.Schema().ClassDeleter(). + WithClassName(className). + Do(ctx) + } +} + +// buildWeaviateFilter converts []Query β†’ Weaviate WhereFilter +func buildWeaviateFilter(queries []Query) *filters.WhereBuilder { + if len(queries) == 0 { + return nil + } + + var operands []*filters.WhereBuilder + for _, q := range queries { + // Convert string operator to filters operator + operator := convertOperator(q.Operator) + + fieldPath := strings.Split(q.Field, ".") + + whereClause := filters.Where(). + WithPath(fieldPath). + WithOperator(operator) + + // Special handling for IsNull and IsNotNull + switch q.Operator { + case QueryOperatorIsNull: + whereClause = whereClause.WithValueBoolean(true) + case QueryOperatorIsNotNull: + whereClause = whereClause.WithValueBoolean(false) + default: + // Set value based on type + switch v := q.Value.(type) { + case string: + whereClause = whereClause.WithValueString(v) + case int: + whereClause = whereClause.WithValueInt(int64(v)) + case int64: + whereClause = whereClause.WithValueInt(v) + case float32: + whereClause = whereClause.WithValueNumber(float64(v)) + case float64: + whereClause = whereClause.WithValueNumber(v) + case bool: + whereClause = whereClause.WithValueBoolean(v) + default: + // Fallback to string conversion + whereClause = whereClause.WithValueString(fmt.Sprintf("%v", v)) + } + } + + operands = append(operands, whereClause) + } + + if len(operands) == 1 { + return operands[0] + } + + // Create AND filter for multiple operands + return filters.Where(). + WithOperator(filters.And). + WithOperands(operands) +} + +// convertOperator converts string operator to filters operator +func convertOperator(op QueryOperator) filters.WhereOperator { + switch op { + case QueryOperatorEqual: + return filters.Equal + case QueryOperatorNotEqual: + return filters.NotEqual + case QueryOperatorLessThan: + return filters.LessThan + case QueryOperatorLessThanOrEqual: + return filters.LessThanEqual + case QueryOperatorGreaterThan: + return filters.GreaterThan + case QueryOperatorGreaterThanOrEqual: + return filters.GreaterThanEqual + case QueryOperatorLike: + return filters.Like + case QueryOperatorContainsAny: + return filters.ContainsAny + case QueryOperatorContainsAll: + return filters.ContainsAll + case QueryOperatorIsNull: + return filters.IsNull + case QueryOperatorIsNotNull: // IsNotNull is not supported by Weaviate, so we use IsNull and negate it. + return filters.IsNull + default: + // Default to Equal if unknown + return filters.Equal + } +} diff --git a/framework/vectorstore/weaviate_test.go b/framework/vectorstore/weaviate_test.go new file mode 100644 index 000000000..aa45076cf --- /dev/null +++ b/framework/vectorstore/weaviate_test.go @@ -0,0 +1,814 @@ +package vectorstore + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/weaviate/weaviate-go-client/v5/weaviate/filters" + "github.com/weaviate/weaviate/entities/models" +) + +// Test constants +const ( + TestTimeout = 30 * time.Second + TestClassName = "TestWeaviate" + TestEmbeddingDim = 384 + DefaultTestScheme = "http" + DefaultTestHost = "localhost:9000" + DefaultTestTimeout = 10 * time.Second +) + +// TestSetup provides common test infrastructure +type TestSetup struct { + Store *WeaviateStore + Logger schemas.Logger + Config WeaviateConfig + ctx context.Context + cancel context.CancelFunc +} + +// NewTestSetup creates a test setup with environment-driven configuration +func NewTestSetup(t *testing.T) *TestSetup { + // Get configuration from environment variables + scheme := getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme) + host := getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost) + apiKey := os.Getenv("WEAVIATE_API_KEY") + + timeoutStr := getEnvWithDefault("WEAVIATE_TIMEOUT", "10s") + timeout, err := time.ParseDuration(timeoutStr) + if err != nil { + timeout = DefaultTestTimeout + } + + config := WeaviateConfig{ + Scheme: scheme, + Host: host, + ApiKey: apiKey, + Timeout: timeout, + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + + store, err := newWeaviateStore(ctx, &config, logger) + if err != nil { + cancel() + t.Fatalf("Failed to create Weaviate store: %v", err) + } + + setup := &TestSetup{ + Store: store, + Logger: logger, + Config: config, + ctx: ctx, + cancel: cancel, + } + + // Ensure class exists for integration tests + if !testing.Short() { + setup.ensureClassExists(t) + } + + return setup +} + +// Cleanup cleans up test resources +func (ts *TestSetup) Cleanup(t *testing.T) { + defer ts.cancel() + + if !testing.Short() { + // Clean up test data + ts.cleanupTestData(t) + } + + if err := ts.Store.Close(ts.ctx, TestClassName); err != nil { + t.Logf("Warning: Failed to close store: %v", err) + } +} + +// ensureClassExists creates the test class in Weaviate +func (ts *TestSetup) ensureClassExists(t *testing.T) { + // Try to get class schema first + exists, err := ts.Store.client.Schema().ClassGetter(). + WithClassName(TestClassName). + Do(ts.ctx) + + if err == nil && exists != nil { + t.Logf("Class %s already exists", TestClassName) + return + } + + // Create class with minimal schema - let Weaviate auto-create properties + class := &models.Class{ + Class: TestClassName, + Properties: []*models.Property{ + { + Name: "key", + DataType: []string{"text"}, + }, + { + Name: "test_type", + DataType: []string{"text"}, + }, + { + Name: "size", + DataType: []string{"int"}, + }, + { + Name: "public", + DataType: []string{"boolean"}, + }, + }, + VectorIndexConfig: map[string]interface{}{ + "distance": "cosine", + }, + } + + err = ts.Store.client.Schema().ClassCreator(). + WithClass(class). + Do(ts.ctx) + + if err != nil { + t.Logf("Warning: Failed to create test class %s: %v", TestClassName, err) + t.Logf("This might be due to auto-schema creation. Continuing...") + } else { + t.Logf("Created test class: %s", TestClassName) + } +} + +// cleanupTestData removes all test objects from the class +func (ts *TestSetup) cleanupTestData(t *testing.T) { + // Delete all objects in the test class + allTestKeys, _, err := ts.Store.GetAll(ts.ctx, TestClassName, []Query{}, []string{}, nil, 1000) + if err != nil { + t.Logf("Warning: Failed to get all test keys: %v", err) + return + } + + for _, key := range allTestKeys { + err := ts.Store.Delete(ts.ctx, TestClassName, key.ID) + if err != nil { + t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err) + } + } + + t.Logf("Cleaned up test class: %s", TestClassName) +} + +// ============================================================================ +// UNIT TESTS +// ============================================================================ + +func TestWeaviateConfig_Validation(t *testing.T) { + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + ctx := context.Background() + + tests := []struct { + name string + config WeaviateConfig + expectError bool + errorMsg string + }{ + { + name: "valid config", + config: WeaviateConfig{ + Scheme: "http", + Host: "localhost:8080", + }, + expectError: false, + }, + { + name: "missing scheme", + config: WeaviateConfig{ + Host: "localhost:8080", + }, + expectError: true, + errorMsg: "scheme and host are required", + }, + { + name: "missing host", + config: WeaviateConfig{ + Scheme: "http", + }, + expectError: true, + errorMsg: "scheme and host are required", + }, + { + name: "with api key", + config: WeaviateConfig{ + Scheme: "https", + Host: "cluster.weaviate.network", + ApiKey: "test-key", + }, + expectError: false, + }, + { + name: "with custom headers", + config: WeaviateConfig{ + Scheme: "http", + Host: "localhost:8080", + Headers: map[string]string{ + "Custom-Header": "value", + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, err := newWeaviateStore(ctx, &tt.config, logger) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, store) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + // Note: This will fail with connection error in unit tests + // but should pass config validation + assert.Nil(t, store) // Expected due to no real Weaviate instance + assert.Error(t, err) // Connection error expected + } + }) + } +} + +func TestDefaultClassName(t *testing.T) { + config := WeaviateConfig{ + Scheme: "http", + Host: "localhost:8080", + } + + // This will fail to connect but should set default class name + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + _, err := newWeaviateStore(context.Background(), &config, logger) + + // Should fail with connection error, but we can't test the default class name + // without mocking the client, which would be more complex + assert.Error(t, err) +} + +func TestBuildWeaviateFilter(t *testing.T) { + tests := []struct { + name string + queries []Query + expected *filters.WhereBuilder // We'll test the structure, not exact equality + isNil bool + }{ + { + name: "empty queries", + queries: []Query{}, + expected: nil, + isNil: true, + }, + { + name: "single string query", + queries: []Query{ + {Field: "category", Operator: QueryOperatorEqual, Value: "tech"}, + }, + isNil: false, + }, + { + name: "single numeric query", + queries: []Query{ + {Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000}, + }, + isNil: false, + }, + { + name: "multiple queries (AND)", + queries: []Query{ + {Field: "category", Operator: QueryOperatorEqual, Value: "tech"}, + {Field: "public", Operator: QueryOperatorEqual, Value: true}, + }, + isNil: false, + }, + { + name: "mixed types", + queries: []Query{ + {Field: "name", Operator: QueryOperatorLike, Value: "test%"}, + {Field: "count", Operator: QueryOperatorLessThan, Value: int64(100)}, + {Field: "active", Operator: QueryOperatorEqual, Value: true}, + {Field: "score", Operator: QueryOperatorGreaterThanOrEqual, Value: 95.5}, + }, + isNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildWeaviateFilter(tt.queries) + + if tt.isNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + // We can't easily test the internal structure without reflection + // or implementing String() methods, but we verify it's not nil + } + }) + } +} + +func TestConvertOperator(t *testing.T) { + tests := []struct { + input QueryOperator + expected filters.WhereOperator + }{ + {QueryOperatorEqual, filters.Equal}, + {QueryOperatorNotEqual, filters.NotEqual}, + {QueryOperatorLessThan, filters.LessThan}, + {QueryOperatorLessThanOrEqual, filters.LessThanEqual}, + {QueryOperatorGreaterThan, filters.GreaterThan}, + {QueryOperatorGreaterThanOrEqual, filters.GreaterThanEqual}, + {QueryOperatorLike, filters.Like}, + {QueryOperatorContainsAny, filters.ContainsAny}, + {QueryOperatorContainsAll, filters.ContainsAll}, + {QueryOperatorIsNull, filters.IsNull}, + {QueryOperatorIsNotNull, filters.IsNull}, + } + + for _, tt := range tests { + t.Run(string(tt.input), func(t *testing.T) { + result := convertOperator(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============================================================================ +// INTEGRATION TESTS (require real Weaviate instance) +// ============================================================================ + +func TestWeaviateStore_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewTestSetup(t) + defer setup.Cleanup(t) + + t.Run("Add and GetChunk", func(t *testing.T) { + testKey := generateUUID() + embedding := generateTestEmbedding(TestEmbeddingDim) + metadata := map[string]interface{}{ + "type": "document", + "size": 1024, + "public": true, + } + + // Add object + err := setup.Store.Add(setup.ctx, TestClassName, testKey, embedding, metadata) + require.NoError(t, err) + + // Small delay to ensure consistency + time.Sleep(100 * time.Millisecond) + + // Get single chunk + result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey) + require.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, "document", result.Properties["type"]) // Should contain metadata + }) + + t.Run("Add without embedding", func(t *testing.T) { + testKey := generateUUID() + metadata := map[string]interface{}{ + "type": "metadata-only", + } + + // Add object without embedding + err := setup.Store.Add(setup.ctx, TestClassName, testKey, nil, metadata) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Retrieve it + result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey) + require.NoError(t, err) + assert.Equal(t, "metadata-only", result.Properties["type"]) + }) +} + +func TestWeaviateStore_FilteringScenarios(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewTestSetup(t) + defer setup.Cleanup(t) + + // Setup test data for filtering scenarios + testData := []struct { + key string + metadata map[string]interface{} + }{ + { + generateUUID(), + map[string]interface{}{ + "type": "pdf", + "size": 1024, + "public": true, + "author": "alice", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "docx", + "size": 2048, + "public": false, + "author": "bob", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "pdf", + "size": 512, + "public": true, + "author": "alice", + }, + }, + { + generateUUID(), + map[string]interface{}{ + "type": "txt", + "size": 256, + "public": true, + "author": "charlie", + }, + }, + } + + filterFields := []string{"type", "size", "public", "author"} + + // Add all test data + for _, item := range testData { + embedding := generateTestEmbedding(TestEmbeddingDim) + err := setup.Store.Add(setup.ctx, TestClassName, item.key, embedding, item.metadata) + require.NoError(t, err) + } + + time.Sleep(500 * time.Millisecond) // Wait for consistency + + t.Run("Filter by numeric comparison", func(t *testing.T) { + queries := []Query{ + {Field: "size", Operator: "GreaterThan", Value: 1000}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 (1024) and doc2 (2048) + }) + + t.Run("Filter by boolean", func(t *testing.T) { + queries := []Query{ + {Field: "public", Operator: "Equal", Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 3) // doc1, doc3, doc4 + }) + + t.Run("Multiple filters (AND)", func(t *testing.T) { + queries := []Query{ + {Field: "type", Operator: "Equal", Value: "pdf"}, + {Field: "public", Operator: "Equal", Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 and doc3 + }) + + t.Run("Complex multi-condition filter", func(t *testing.T) { + queries := []Query{ + {Field: "author", Operator: "Equal", Value: "alice"}, + {Field: "size", Operator: "LessThan", Value: 2000}, + {Field: "public", Operator: "Equal", Value: true}, + } + + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1 and doc3 (both by alice, < 2000 size, public) + }) + + t.Run("Pagination test", func(t *testing.T) { + // Test with limit of 2 + results, cursor, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, nil, 2) + require.NoError(t, err) + assert.Len(t, results, 2) + + if cursor != nil { + // Get next page + nextResults, _, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, cursor, 2) + require.NoError(t, err) + assert.LessOrEqual(t, len(nextResults), 2) + t.Logf("First page: %d results, Next page: %d results", len(results), len(nextResults)) + } + }) +} + +func TestWeaviateStore_CompleteUseCases(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewTestSetup(t) + defer setup.Cleanup(t) + + t.Run("Document Storage & Retrieval Scenario", func(t *testing.T) { + // Add documents with different types + documents := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"type": "pdf", "size": 1024, "public": true}, + }, + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"type": "docx", "size": 2048, "public": false}, + }, + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"type": "pdf", "size": 512, "public": true}, + }, + } + + filterFields := []string{"type", "size", "public", "author"} + + for _, doc := range documents { + err := setup.Store.Add(setup.ctx, TestClassName, doc.key, doc.embedding, doc.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Test various retrieval patterns + + // Get PDF documents + pdfQuery := []Query{{Field: "type", Operator: "Equal", Value: "pdf"}} + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, pdfQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc3 + + // Get large documents (size > 1000) + sizeQuery := []Query{{Field: "size", Operator: "GreaterThan", Value: 1000}} + results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, sizeQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc2 + + // Get public PDFs + combinedQuery := []Query{ + {Field: "public", Operator: "Equal", Value: true}, + {Field: "type", Operator: "Equal", Value: "pdf"}, + } + results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, combinedQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // doc1, doc3 + + // Vector similarity search + queryEmbedding := documents[0].embedding // Similar to doc1 + vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, nil, filterFields, 0.8, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(vectorResults), 1) + }) + + t.Run("User Content Management Scenario", func(t *testing.T) { + // Add user content with metadata + userContent := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"user": "alice", "lang": "en", "category": "tech"}, + }, + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"user": "bob", "lang": "es", "category": "tech"}, + }, + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{"user": "alice", "lang": "en", "category": "sports"}, + }, + } + + filterFields := []string{"user", "lang", "category"} + + for _, content := range userContent { + err := setup.Store.Add(setup.ctx, TestClassName, content.key, content.embedding, content.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Test user-specific filtering + aliceQuery := []Query{{Field: "user", Operator: "Equal", Value: "alice"}} + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, aliceQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 2) // Alice's content + + // English tech content + techEnQuery := []Query{ + {Field: "lang", Operator: "Equal", Value: "en"}, + {Field: "category", Operator: "Equal", Value: "tech"}, + } + results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, techEnQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 1) // user1_content + + // Alice's similar content (semantic search with user filter) + aliceFilter := []Query{{Field: "user", Operator: "Equal", Value: "alice"}} + queryEmbedding := userContent[0].embedding + vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, aliceFilter, filterFields, 0.1, 10) + require.NoError(t, err) + assert.Len(t, vectorResults, 2) // Both of Alice's content + }) + + t.Run("Semantic Cache-like Workflow", func(t *testing.T) { + // Add request-response pairs with parameters + cacheEntries := []struct { + key string + embedding []float32 + metadata map[string]interface{} + }{ + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{ + "request_hash": "abc123", + "user": "u1", + "lang": "en", + "response": "answer1", + }, + }, + { + generateUUID(), + generateTestEmbedding(TestEmbeddingDim), + map[string]interface{}{ + "request_hash": "def456", + "user": "u1", + "lang": "es", + "response": "answer2", + }, + }, + } + + filterFields := []string{"request_hash", "user", "lang", "response"} + + for _, entry := range cacheEntries { + err := setup.Store.Add(setup.ctx, TestClassName, entry.key, entry.embedding, entry.metadata) + require.NoError(t, err) + } + + time.Sleep(300 * time.Millisecond) + + // Test hash-based direct retrieval (exact match) + hashQuery := []Query{{Field: "request_hash", Operator: "Equal", Value: "abc123"}} + results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, hashQuery, filterFields, nil, 10) + require.NoError(t, err) + assert.Len(t, results, 1) + + // Test semantic search with user and language filters + userLangFilter := []Query{ + {Field: "user", Operator: "Equal", Value: "u1"}, + {Field: "lang", Operator: "Equal", Value: "en"}, + } + similarEmbedding := generateSimilarEmbedding(cacheEntries[0].embedding, 0.9) + vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, similarEmbedding, userLangFilter, filterFields, 0.7, 10) + require.NoError(t, err) + assert.Len(t, vectorResults, 1) // Should find English content for u1 + }) +} + +// ============================================================================ +// INTERFACE COMPLIANCE TESTS +// ============================================================================ + +func TestWeaviateStore_InterfaceCompliance(t *testing.T) { + // Verify that WeaviateStore implements VectorStore interface + var _ VectorStore = (*WeaviateStore)(nil) +} + +func TestVectorStoreFactory_Weaviate(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) + config := &Config{ + Enabled: true, + Type: VectorStoreTypeWeaviate, + Config: WeaviateConfig{ + Scheme: getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme), + Host: getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost), + ApiKey: os.Getenv("WEAVIATE_API_KEY"), + }, + } + + store, err := NewVectorStore(context.Background(), config, logger) + if err != nil { + t.Skipf("Could not create Weaviate store: %v", err) + } + defer store.Close(context.Background(), TestClassName) + + // Verify it's actually a WeaviateStore + weaviateStore, ok := store.(*WeaviateStore) + assert.True(t, ok) + assert.NotNil(t, weaviateStore) +} + +func TestWeaviateStore_NamespaceDimensionHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration tests in short mode") + } + + setup := NewTestSetup(t) + defer setup.Cleanup(t) + + testClassName := "TestDimensionHandling" + + t.Run("Recreate class with different dimension should not crash", func(t *testing.T) { + properties := map[string]VectorStoreProperties{ + "type": {DataType: VectorStorePropertyTypeString}, + "test": {DataType: VectorStorePropertyTypeString}, + } + + // Step 1: Create class with dimension 512 + err := setup.Store.CreateNamespace(setup.ctx, testClassName, 512, properties) + require.NoError(t, err) + + // Add a document with 512-dimensional embedding + testKey512 := generateUUID() + embedding512 := generateTestEmbedding(512) + metadata := map[string]interface{}{ + "type": "test_doc", + "test": "dimension_512", + } + + err = setup.Store.Add(setup.ctx, testClassName, testKey512, embedding512, metadata) + require.NoError(t, err) + + // Verify it was added + result, err := setup.Store.GetChunk(setup.ctx, testClassName, testKey512) + require.NoError(t, err) + assert.Equal(t, "dimension_512", result.Properties["test"]) + + // Step 2: Delete the class/namespace + err = setup.Store.DeleteNamespace(setup.ctx, testClassName) + require.NoError(t, err) + + // Step 3: Create class with same name but different dimension - should not crash + err = setup.Store.CreateNamespace(setup.ctx, testClassName, 1024, properties) + require.NoError(t, err) + + // Add a document with 1024-dimensional embedding + testKey1024 := generateUUID() + embedding1024 := generateTestEmbedding(1024) + metadata1024 := map[string]interface{}{ + "type": "test_doc", + "test": "dimension_1024", + } + + err = setup.Store.Add(setup.ctx, testClassName, testKey1024, embedding1024, metadata1024) + require.NoError(t, err) + + // Verify new document exists + result, err = setup.Store.GetChunk(setup.ctx, testClassName, testKey1024) + require.NoError(t, err) + assert.Equal(t, "dimension_1024", result.Properties["test"]) + + // Verify vector search works with new dimension + vectorResults, err := setup.Store.GetNearest(setup.ctx, testClassName, embedding1024, nil, []string{"type", "test"}, 0.8, 10) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(vectorResults), 1) + assert.NotNil(t, vectorResults[0].Score) + + // Cleanup + err = setup.Store.DeleteNamespace(setup.ctx, testClassName) + if err != nil { + t.Logf("Warning: Failed to cleanup class: %v", err) + } + }) +} diff --git a/framework/version b/framework/version new file mode 100644 index 000000000..154b9fce5 --- /dev/null +++ b/framework/version @@ -0,0 +1 @@ +1.0.23 diff --git a/npx/bin.js b/npx/bin.js new file mode 100644 index 000000000..1e2ee9456 --- /dev/null +++ b/npx/bin.js @@ -0,0 +1,221 @@ +#!/usr/bin/env node + +import { execFileSync } from "child_process"; +import { chmodSync, createWriteStream, existsSync, fsyncSync } from "fs"; +import { tmpdir } from "os"; +import { join } from "path"; +import { Readable } from "stream"; + +const BASE_URL = "https://downloads.getmaxim.ai"; + +// Parse transport version from command line arguments +function parseTransportVersion() { + const args = process.argv.slice(2); + let transportVersion = "latest"; // Default to latest + + // Find --transport-version argument + const versionArgIndex = args.findIndex(arg => arg.startsWith("--transport-version")); + + if (versionArgIndex !== -1) { + const versionArg = args[versionArgIndex]; + + if (versionArg.includes("=")) { + // Format: --transport-version=v1.2.3 + transportVersion = versionArg.split("=")[1]; + } else if (versionArgIndex + 1 < args.length) { + // Format: --transport-version v1.2.3 + transportVersion = args[versionArgIndex + 1]; + } + + // Remove the transport-version arguments from args array so they don't get passed to the binary + if (versionArg.includes("=")) { + args.splice(versionArgIndex, 1); + } else { + args.splice(versionArgIndex, 2); + } + } + + return { version: validateTransportVersion(transportVersion), remainingArgs: args }; +} + +// Validate transport version format +function validateTransportVersion(version) { + if (version === "latest") { + return version; + } + + // Check if version matches v{x.x.x} format + const versionRegex = /^v\d+\.\d+\.\d+(?:-[0-9A-Za-z.-]+)?$/; + if (versionRegex.test(version)) { + return version; + } + + console.error(`Invalid transport version format: ${version}`); + console.error(`Transport version must be either "latest", "v1.2.3", or "v1.2.3-prerelease1"`); + process.exit(1); +} + +const { version: VERSION, remainingArgs } = parseTransportVersion(); + +async function getPlatformArchAndBinary() { + const platform = process.platform; + const arch = process.arch; + + let platformDir; + let archDir; + let binaryName; + + if (platform === "darwin") { + platformDir = "darwin"; + if (arch === "arm64") archDir = "arm64"; + else archDir = "amd64"; + binaryName = "bifrost-http"; + } else if (platform === "linux") { + platformDir = "linux"; + if (arch === "x64") archDir = "amd64"; + else if (arch === "ia32") archDir = "386"; + else archDir = arch; // fallback + binaryName = "bifrost-http"; + } else if (platform === "win32") { + platformDir = "windows"; + if (arch === "x64") archDir = "amd64"; + else if (arch === "ia32") archDir = "386"; + else archDir = arch; // fallback + binaryName = "bifrost-http.exe"; + } else { + console.error(`Unsupported platform/arch: ${platform}/${arch}`); + process.exit(1); + } + + return { platformDir, archDir, binaryName }; +} + +async function downloadBinary(url, dest) { + // console.log(`πŸ”„ Downloading binary from ${url}...`); + + const res = await fetch(url); + + if (!res.ok) { + console.error(`❌ Download failed: ${res.status} ${res.statusText}`); + process.exit(1); + } + + const contentLength = res.headers.get('content-length'); + const totalSize = contentLength ? parseInt(contentLength, 10) : null; + let downloadedSize = 0; + + const fileStream = createWriteStream(dest, { flags: "w" }); + await new Promise((resolve, reject) => { + try { + // Convert the fetch response body to a Node.js readable stream + const nodeStream = Readable.fromWeb(res.body); + + // Add progress tracking + nodeStream.on('data', (chunk) => { + downloadedSize += chunk.length; + if (totalSize) { + const progress = ((downloadedSize / totalSize) * 100).toFixed(1); + process.stdout.write(`\r⏱️ Downloading Binary: ${progress}% (${formatBytes(downloadedSize)}/${formatBytes(totalSize)})`); + } else { + process.stdout.write(`\r⏱️ Downloaded: ${formatBytes(downloadedSize)}`); + } + }); + + nodeStream.pipe(fileStream); + fileStream.on("finish", () => { + process.stdout.write('\n'); + + // Ensure file is fully written to disk + try { + fsyncSync(fileStream.fd); + } catch (syncError) { + // fsync might fail on some systems, ignore + } + + resolve(); + }); + fileStream.on("error", reject); + nodeStream.on("error", reject); + } catch (error) { + reject(error); + } + }); + + chmodSync(dest, 0o755); +} + +function formatBytes(bytes) { + if (bytes === 0) return '0 B'; + const k = 1024; + const sizes = ['B', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i]; +} + +(async () => { + const platformInfo = await getPlatformArchAndBinary(); + const { platformDir, archDir, binaryName } = platformInfo; + + // For future use when we want to add multiple fallback binaries + const downloadUrls = []; + + downloadUrls.push(`${BASE_URL}/bifrost/${VERSION}/${platformDir}/${archDir}/${binaryName}`); + + let lastError = null; + let binaryWorking = false; + + for (let i = 0; i < downloadUrls.length; i++) { + const downloadUrl = downloadUrls[i]; + // Use unique file path for each attempt to avoid ETXTBSY + const binaryPath = join(tmpdir(), `${binaryName}-${i}`); + + try { + await downloadBinary(downloadUrl, binaryPath); + + // Verify the binary is executable before trying to run it + if (!existsSync(binaryPath)) { + throw new Error(`Binary not found at: ${binaryPath}`); + } + + // Add a small delay to ensure file is fully written and not busy + await new Promise(resolve => setTimeout(resolve, 100)); + + // Test if the binary can execute + try { + execFileSync(binaryPath, remainingArgs, { stdio: "inherit" }); + binaryWorking = true; + break; + } catch (execError) { + // If execution fails (ENOENT, ETXTBSY, etc.), try next binary + lastError = execError; + continue; + } + } catch (downloadError) { + lastError = downloadError; + // Continue to next URL silently + } + } + + if (!binaryWorking) { + console.error(`❌ Failed to start Bifrost. Error:`, lastError.message); + + // Show critical error details for troubleshooting + if (lastError.code) { + console.error(`Error code: ${lastError.code}`); + } + if (lastError.errno) { + console.error(`System error: ${lastError.errno}`); + } + if (lastError.signal) { + console.error(`Signal: ${lastError.signal}`); + } + + // For specific Linux issues, show diagnostic info + if (process.platform === 'linux' && (lastError.code === 'ENOENT' || lastError.code === 'ETXTBSY')) { + console.error(`\nπŸ’‘ This appears to be a Linux compatibility issue.`); + console.error(` The binary may be incompatible with your Linux distribution.`); + } + + process.exit(lastError.status || 1); + } +})(); diff --git a/npx/package-lock.json b/npx/package-lock.json new file mode 100644 index 000000000..0dfb91807 --- /dev/null +++ b/npx/package-lock.json @@ -0,0 +1,19 @@ +{ + "name": "@maximhq/bifrost", + "version": "1.0.4", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "@maximhq/bifrost", + "version": "1.0.4", + "license": "Apache-2.0", + "bin": { + "bifrost": "bin.js" + }, + "engines": { + "node": ">=18.0.0" + } + } + } +} diff --git a/npx/package.json b/npx/package.json new file mode 100644 index 000000000..8c8a7c558 --- /dev/null +++ b/npx/package.json @@ -0,0 +1,24 @@ +{ + "name": "@maximhq/bifrost", + "version": "1.0.5", + "description": "High-performance AI gateway CLI - connect to 12+ providers through a single API", + "keywords": ["ai", "gateway", "openai", "anthropic", "cli", "bifrost"], + "homepage": "https://github.com/maximhq/bifrost", + "repository": { + "type": "git", + "url": "https://github.com/maximhq/bifrost.git" + }, + "license": "Apache-2.0", + "author": "Maxim HQ", + "engines": { + "node": ">=18.0.0" + }, + "publishConfig": { + "access": "public" + }, + "bin": { + "bifrost": "bin.js" + }, + "type": "module", + "dependencies": {} +} \ No newline at end of file diff --git a/plugins/go.mod b/plugins/go.mod deleted file mode 100644 index 82e50b301..000000000 --- a/plugins/go.mod +++ /dev/null @@ -1,8 +0,0 @@ -module github.com/maximhq/bifrost/plugins - -go 1.24.1 - -require ( - github.com/maximhq/bifrost/core v1.0.1 - github.com/maximhq/maxim-go v0.1.1 -) diff --git a/plugins/go.sum b/plugins/go.sum deleted file mode 100644 index b8cb7b66e..000000000 --- a/plugins/go.sum +++ /dev/null @@ -1,4 +0,0 @@ -github.com/maximhq/bifrost/core v1.0.1 h1:B0u6o13faUexA+V0EUU0bsLW2dHg9+R2TZKQzPzCxlY= -github.com/maximhq/bifrost/core v1.0.1/go.mod h1:4+Ept2EnX1EEjH/mBuSwK7eE56znI/BCoCbIrx25/x8= -github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= -github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/governance/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod new file mode 100644 index 000000000..9501860d4 --- /dev/null +++ b/plugins/governance/go.mod @@ -0,0 +1,89 @@ +module github.com/maximhq/bifrost/plugins/governance + +go 1.24.1 + +toolchain go1.24.3 + +require gorm.io/gorm v1.30.1 + +require ( + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/bifrost/framework v1.0.23 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/redis/go-redis/v9 v9.12.1 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/weaviate/weaviate v1.31.5 // indirect + github.com/weaviate/weaviate-go-client/v5 v5.2.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect +) diff --git a/plugins/governance/go.sum b/plugins/governance/go.sum new file mode 100644 index 000000000..a8bac98bd --- /dev/null +++ b/plugins/governance/go.sum @@ -0,0 +1,355 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/bifrost/framework v1.0.23 h1:erRPP9Q0WIaUgxuLBN8urd77SObEF9irPvpV9Wbegyk= +github.com/maximhq/bifrost/framework v1.0.23/go.mod h1:uEB0iuQtFfuFuMrhccMsb+51mf8m8X2tB8ZlDVoJUbM= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/plugins/governance/main.go b/plugins/governance/main.go new file mode 100644 index 000000000..d3b51ea46 --- /dev/null +++ b/plugins/governance/main.go @@ -0,0 +1,326 @@ +// Package governance provides comprehensive governance plugin for Bifrost +package governance + +import ( + "context" + "fmt" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/framework/pricing" +) + +// PluginName is the name of the governance plugin +const PluginName = "governance" + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + governanceRejectedContextKey contextKey = "bf-governance-rejected" + governanceIsCacheReadContextKey contextKey = "bf-governance-is-cache-read" + governanceIsBatchContextKey contextKey = "bf-governance-is-batch" +) + +// Config is the configuration for the governance plugin +type Config struct { + IsVkMandatory *bool `json:"is_vk_mandatory"` +} + +// GovernancePlugin implements the main governance plugin with hierarchical budget system +type GovernancePlugin struct { + // Core components with clear separation of concerns + store *GovernanceStore // Pure data access layer + resolver *BudgetResolver // Pure decision engine for hierarchical governance + tracker *UsageTracker // Business logic owner (updates, resets, persistence) + + // Dependencies + configStore configstore.ConfigStore + pricingManager *pricing.PricingManager + logger schemas.Logger + + isVkMandatory *bool +} + +// Init creates a new governance plugin with cleanly segregated components +// All governance features are enabled by default with optimized settings +func Init(ctx context.Context, config *Config, logger schemas.Logger, store configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, pricingManager *pricing.PricingManager) (*GovernancePlugin, error) { + if store == nil { + logger.Warn("governance plugin requires config store to persist data, running in memory only mode") + } + if pricingManager == nil { + logger.Warn("governance plugin requires pricing manager to calculate cost, all cost calculations will be skipped.") + } + + governanceStore, err := NewGovernanceStore(logger, store, governanceConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize governance store: %w", err) + } + // Initialize components in dependency order with fixed, optimal settings + // Resolver (pure decision engine for hierarchical governance, depends only on store) + resolver := NewBudgetResolver(governanceStore, logger) + + // 3. Tracker (business logic owner, depends on store and resolver) + tracker := NewUsageTracker(governanceStore, resolver, store, logger) + + // 4. Perform startup reset check for any expired limits from downtime + if store != nil { + if err := tracker.PerformStartupResets(); err != nil { + logger.Warn("startup reset failed: %v", err) + // Continue initialization even if startup reset fails (non-critical) + } + } + + plugin := &GovernancePlugin{ + store: governanceStore, + resolver: resolver, + tracker: tracker, + configStore: store, + pricingManager: pricingManager, + logger: logger, + isVkMandatory: config.IsVkMandatory, + } + + return plugin, nil +} + +// GetName returns the name of the plugin +func (p *GovernancePlugin) GetName() string { + return PluginName +} + +// PreHook intercepts requests before they are processed (governance decision point) +func (p *GovernancePlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Extract governance headers and virtual key using utility functions + headers := extractHeadersFromContext(*ctx) + virtualKey := getStringFromContext(*ctx, ContextKey("x-bf-vk")) + requestID := getStringFromContext(*ctx, schemas.BifrostContextKey("request-id")) + + if virtualKey == "" { + if p.isVkMandatory != nil && *p.isVkMandatory { + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Type: bifrost.Ptr("virtual_key_required"), + StatusCode: bifrost.Ptr(400), + Error: schemas.ErrorField{ + Message: "x-bf-vk header is missing", + }, + }, + }, nil + } else { + return req, nil, nil + } + } + + // Extract provider and model from request + provider := req.Provider + model := req.Model + + // Store original request provider/model and operation flags in context for PostHook + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestProvider, provider) + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestModel, model) + + // Create request context for evaluation + evaluationRequest := &EvaluationRequest{ + VirtualKey: virtualKey, + Provider: provider, + Model: model, + Headers: headers, + RequestID: requestID, + } + + // Use resolver to make governance decision (pure decision engine) + result := p.resolver.EvaluateRequest(ctx, evaluationRequest) + + if result.Decision != DecisionAllow { + if ctx != nil { + if _, ok := (*ctx).Value(governanceRejectedContextKey).(bool); !ok { + *ctx = context.WithValue(*ctx, governanceRejectedContextKey, true) + } + } + } + + // Handle decision + switch result.Decision { + case DecisionAllow: + return req, nil, nil + + case DecisionVirtualKeyNotFound, DecisionVirtualKeyBlocked, DecisionModelBlocked, DecisionProviderBlocked: + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(result.Decision)), + StatusCode: bifrost.Ptr(403), + Error: schemas.ErrorField{ + Message: result.Reason, + }, + }, + }, nil + + case DecisionRateLimited, DecisionTokenLimited, DecisionRequestLimited: + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(result.Decision)), + StatusCode: bifrost.Ptr(429), + Error: schemas.ErrorField{ + Message: result.Reason, + }, + }, + }, nil + + case DecisionBudgetExceeded: + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(result.Decision)), + StatusCode: bifrost.Ptr(402), + Error: schemas.ErrorField{ + Message: result.Reason, + }, + }, + }, nil + + default: + // Fallback to deny for unknown decisions + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(result.Decision)), + Error: schemas.ErrorField{ + Message: "Governance decision error", + }, + }, + }, nil + } +} + +// PostHook processes the response and updates usage tracking (business logic execution) +func (p *GovernancePlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if _, ok := (*ctx).Value(governanceRejectedContextKey).(bool); ok { + return result, err, nil + } + + // Extract governance information + headers := extractHeadersFromContext(*ctx) + virtualKey := getStringFromContext(*ctx, ContextKey("x-bf-vk")) + requestID := getStringFromContext(*ctx, schemas.BifrostContextKey("request-id")) + + // Skip if no virtual key + if virtualKey == "" { + return result, err, nil + } + + // Extract provider and model from stored context values (set in PreHook) + var provider schemas.ModelProvider + var model string + var requestType schemas.RequestType + + if providerValue := (*ctx).Value(schemas.BifrostContextKeyRequestProvider); providerValue != nil { + if p, ok := providerValue.(schemas.ModelProvider); ok { + provider = p + } + } + if modelValue := (*ctx).Value(schemas.BifrostContextKeyRequestModel); modelValue != nil { + if m, ok := modelValue.(string); ok { + model = m + } + } + if requestTypeValue := (*ctx).Value(schemas.BifrostContextKeyRequestType); requestTypeValue != nil { + if r, ok := requestTypeValue.(schemas.RequestType); ok { + requestType = r + } + } + + // If we couldn't get provider/model from context, skip usage tracking + if provider == "" || model == "" { + p.logger.Debug("Could not extract provider/model from context, skipping usage tracking") + return result, err, nil + } + + // Extract cache and batch flags from context + isCacheRead := false + isBatch := false + if val := (*ctx).Value(governanceIsCacheReadContextKey); val != nil { + if b, ok := val.(bool); ok { + isCacheRead = b + } + } + if val := (*ctx).Value(governanceIsBatchContextKey); val != nil { + if b, ok := val.(bool); ok { + isBatch = b + } + } + + // Extract team/customer info for audit trail + var teamID, customerID *string + if teamIDValue := headers["x-bf-team"]; teamIDValue != "" { + teamID = &teamIDValue + } + if customerIDValue := headers["x-bf-customer"]; customerIDValue != "" { + customerID = &customerIDValue + } + + go p.postHookWorker(result, provider, model, requestType, virtualKey, requestID, teamID, customerID, isCacheRead, isBatch, bifrost.IsFinalChunk(ctx)) + + return result, err, nil +} + +// Cleanup shuts down all components gracefully +func (p *GovernancePlugin) Cleanup() error { + if err := p.tracker.Cleanup(); err != nil { + return err + } + + return nil +} + +func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID string, teamID, customerID *string, isCacheRead, isBatch bool, isFinalChunk bool) { + // Determine if request was successful + success := (result != nil) + + // Streaming detection + isStreaming := bifrost.IsStreamRequestType(requestType) + hasUsageData := hasUsageData(result) + + // Extract usage information from response (including speech and transcribe) + var tokensUsed int64 + + if result != nil { + if result.Usage != nil { + tokensUsed = int64(result.Usage.TotalTokens) + } else if result.Speech != nil && result.Speech.Usage != nil { + tokensUsed = int64(result.Speech.Usage.TotalTokens) + } else if result.Transcribe != nil && result.Transcribe.Usage != nil && result.Transcribe.Usage.TotalTokens != nil { + tokensUsed = int64(*result.Transcribe.Usage.TotalTokens) + } + } + + cost := 0.0 + if !isStreaming || (isStreaming && isFinalChunk) { + if p.pricingManager != nil { + cost = p.pricingManager.CalculateCost(result, provider, model, requestType) + } + } + + // Create usage update for tracker (business logic) + usageUpdate := &UsageUpdate{ + VirtualKey: virtualKey, + Provider: provider, + Model: model, + Success: success, + TokensUsed: tokensUsed, + Cost: cost, + RequestID: requestID, + TeamID: teamID, + CustomerID: customerID, + IsStreaming: isStreaming, + IsFinalChunk: isFinalChunk, + HasUsageData: hasUsageData, + } + + // Queue usage update asynchronously using tracker + p.tracker.UpdateUsage(usageUpdate) +} + +// GetGovernanceStore returns the governance store +func (p *GovernancePlugin) GetGovernanceStore() *GovernanceStore { + return p.store +} diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go new file mode 100644 index 000000000..09672123b --- /dev/null +++ b/plugins/governance/resolver.go @@ -0,0 +1,231 @@ +// Package governance provides the budget evaluation and decision engine +package governance + +import ( + "context" + "fmt" + "slices" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" +) + +// Decision represents the result of governance evaluation +type Decision string + +const ( + DecisionAllow Decision = "allow" + DecisionVirtualKeyNotFound Decision = "virtual_key_not_found" + DecisionVirtualKeyBlocked Decision = "virtual_key_blocked" + DecisionRateLimited Decision = "rate_limited" + DecisionBudgetExceeded Decision = "budget_exceeded" + DecisionTokenLimited Decision = "token_limited" + DecisionRequestLimited Decision = "request_limited" + DecisionModelBlocked Decision = "model_blocked" + DecisionProviderBlocked Decision = "provider_blocked" +) + +// EvaluationRequest contains the context for evaluating a request +type EvaluationRequest struct { + VirtualKey string `json:"virtual_key"` + Provider schemas.ModelProvider `json:"provider"` + Model string `json:"model"` + Headers map[string]string `json:"headers"` + RequestID string `json:"request_id"` +} + +// EvaluationResult contains the complete result of governance evaluation +type EvaluationResult struct { + Decision Decision `json:"decision"` + Reason string `json:"reason"` + VirtualKey *configstore.TableVirtualKey `json:"virtual_key,omitempty"` + RateLimitInfo *configstore.TableRateLimit `json:"rate_limit_info,omitempty"` + BudgetInfo []*configstore.TableBudget `json:"budget_info,omitempty"` // All budgets in hierarchy + UsageInfo *UsageInfo `json:"usage_info,omitempty"` +} + +// UsageInfo represents current usage levels for rate limits and budgets +type UsageInfo struct { + // Rate limit usage + TokensUsedMinute int64 `json:"tokens_used_minute"` + TokensUsedHour int64 `json:"tokens_used_hour"` + TokensUsedDay int64 `json:"tokens_used_day"` + RequestsUsedMinute int64 `json:"requests_used_minute"` + RequestsUsedHour int64 `json:"requests_used_hour"` + RequestsUsedDay int64 `json:"requests_used_day"` + + // Budget usage + VKBudgetUsage int64 `json:"vk_budget_usage"` + TeamBudgetUsage int64 `json:"team_budget_usage"` + CustomerBudgetUsage int64 `json:"customer_budget_usage"` +} + +// BudgetResolver provides decision logic for the new hierarchical governance system +type BudgetResolver struct { + store *GovernanceStore + logger schemas.Logger +} + +// NewBudgetResolver creates a new budget-based governance resolver +func NewBudgetResolver(store *GovernanceStore, logger schemas.Logger) *BudgetResolver { + return &BudgetResolver{ + store: store, + logger: logger, + } +} + +// EvaluateRequest evaluates a request against the new hierarchical governance system +func (r *BudgetResolver) EvaluateRequest(ctx *context.Context, evaluationRequest *EvaluationRequest) *EvaluationResult { + // 1. Validate virtual key exists and is active + vk, exists := r.store.GetVirtualKey(evaluationRequest.VirtualKey) + if !exists { + return &EvaluationResult{ + Decision: DecisionVirtualKeyNotFound, + Reason: "Virtual key not found", + } + } + + if !vk.IsActive { + return &EvaluationResult{ + Decision: DecisionVirtualKeyBlocked, + Reason: "Virtual key is inactive", + } + } + + // 2. Check model filtering + if !r.isModelAllowed(vk, evaluationRequest.Model) { + return &EvaluationResult{ + Decision: DecisionModelBlocked, + Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", evaluationRequest.Model), + VirtualKey: vk, + } + } + + // 3. Check provider filtering + if !r.isProviderAllowed(vk, evaluationRequest.Provider) { + return &EvaluationResult{ + Decision: DecisionProviderBlocked, + Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", evaluationRequest.Provider), + VirtualKey: vk, + } + } + + // 4. Check rate limits (VK level only) + if rateLimitResult := r.checkRateLimits(vk); rateLimitResult != nil { + return rateLimitResult + } + + // 5. Check budget hierarchy (VK β†’ Team β†’ Customer) + if budgetResult := r.checkBudgetHierarchy(vk); budgetResult != nil { + return budgetResult + } + + if vk.Keys != nil { + includeOnlyKeys := make([]string, 0, len(vk.Keys)) + for _, dbKey := range vk.Keys { + includeOnlyKeys = append(includeOnlyKeys, dbKey.KeyID) + } + + if len(includeOnlyKeys) > 0 { + *ctx = context.WithValue(*ctx, schemas.BifrostContextKey("bf-governance-include-only-keys"), includeOnlyKeys) + } + } + + // All checks passed + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "Request allowed by governance policy", + VirtualKey: vk, + } +} + +// isModelAllowed checks if the requested model is allowed for this VK +func (r *BudgetResolver) isModelAllowed(vk *configstore.TableVirtualKey, model string) bool { + // Empty AllowedModels means all models are allowed + if len(vk.AllowedModels) == 0 { + return true + } + + return slices.Contains(vk.AllowedModels, model) +} + +// isProviderAllowed checks if the requested provider is allowed for this VK +func (r *BudgetResolver) isProviderAllowed(vk *configstore.TableVirtualKey, provider schemas.ModelProvider) bool { + // Empty AllowedProviders means all providers are allowed + if len(vk.AllowedProviders) == 0 { + return true + } + + return slices.Contains(vk.AllowedProviders, string(provider)) +} + +// checkRateLimits checks the VK's rate limits using flexible approach +func (r *BudgetResolver) checkRateLimits(vk *configstore.TableVirtualKey) *EvaluationResult { + // No rate limits defined + if vk.RateLimit == nil { + return nil + } + + rateLimit := vk.RateLimit + + // Check if any rate limits are exceeded + var violations []string + + // Token limits + if rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration + } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage, *rateLimit.TokenMaxLimit, duration)) + } + + // Request limits + if rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration + } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage, *rateLimit.RequestMaxLimit, duration)) + } + + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited + } + } + + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Rate limits exceeded: %v", violations), + VirtualKey: vk, + RateLimitInfo: rateLimit, + } + } + + return nil // No rate limit violations +} + +// checkBudgetHierarchy checks the budget hierarchy atomically (VK β†’ Team β†’ Customer) +func (r *BudgetResolver) checkBudgetHierarchy(vk *configstore.TableVirtualKey) *EvaluationResult { + // Use atomic budget checking to prevent race conditions + if err := r.store.CheckBudget(vk); err != nil { + r.logger.Debug(fmt.Sprintf("Atomic budget check failed for VK %s: %s", vk.ID, err.Error())) + + return &EvaluationResult{ + Decision: DecisionBudgetExceeded, + Reason: fmt.Sprintf("Budget check failed: %s", err.Error()), + VirtualKey: vk, + } + } + + return nil // No budget violations +} diff --git a/plugins/governance/store.go b/plugins/governance/store.go new file mode 100644 index 000000000..d12f81e1a --- /dev/null +++ b/plugins/governance/store.go @@ -0,0 +1,681 @@ +// Package governance provides the in-memory cache store for fast governance data access +package governance + +import ( + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// GovernanceStore provides in-memory cache for governance data with fast, non-blocking access +type GovernanceStore struct { + // Core data maps using sync.Map for lock-free reads + virtualKeys sync.Map // string -> *VirtualKey (VK value -> VirtualKey with preloaded relationships) + teams sync.Map // string -> *Team (Team ID -> Team) + customers sync.Map // string -> *Customer (Customer ID -> Customer) + budgets sync.Map // string -> *Budget (Budget ID -> Budget) + + // Config store for refresh operations + configStore configstore.ConfigStore + + // Logger + logger schemas.Logger +} + +// NewGovernanceStore creates a new in-memory governance store +func NewGovernanceStore(logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig) (*GovernanceStore, error) { + store := &GovernanceStore{ + configStore: configStore, + logger: logger, + } + + if configStore != nil { + // Load initial data from database + if err := store.loadFromDatabase(); err != nil { + return nil, fmt.Errorf("failed to load initial data: %w", err) + } + } else { + if err := store.loadFromConfigMemory(governanceConfig); err != nil { + return nil, fmt.Errorf("failed to load governance data from config memory: %w", err) + } + } + + store.logger.Info("governance store initialized successfully") + return store, nil +} + +// GetVirtualKey retrieves a virtual key by its value (lock-free) with all relationships preloaded +func (gs *GovernanceStore) GetVirtualKey(vkValue string) (*configstore.TableVirtualKey, bool) { + value, exists := gs.virtualKeys.Load(vkValue) + if !exists || value == nil { + return nil, false + } + + vk, ok := value.(*configstore.TableVirtualKey) + if !ok || vk == nil { + return nil, false + } + return vk, true +} + +// GetAllBudgets returns all budgets (for background reset operations) +func (gs *GovernanceStore) GetAllBudgets() map[string]*configstore.TableBudget { + result := make(map[string]*configstore.TableBudget) + gs.budgets.Range(func(key, value interface{}) bool { + // Type-safe conversion + keyStr, keyOk := key.(string) + budget, budgetOk := value.(*configstore.TableBudget) + + if keyOk && budgetOk && budget != nil { + result[keyStr] = budget + } + return true // continue iteration + }) + return result +} + +// CheckBudget performs budget checking using in-memory store data (lock-free for high performance) +func (gs *GovernanceStore) CheckBudget(vk *configstore.TableVirtualKey) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") + } + + // Use helper to collect budgets and their names (lock-free) + budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(vk) + + // Check each budget in hierarchy order using in-memory data + for i, budget := range budgetsToCheck { + // Check if budget needs reset (in-memory check) + if budget.ResetDuration != "" { + if duration, err := configstore.ParseDuration(budget.ResetDuration); err == nil { + if time.Since(budget.LastReset).Round(time.Millisecond) >= duration { + // Budget expired but hasn't been reset yet - treat as reset + // Note: actual reset will happen in post-hook via AtomicBudgetUpdate + continue // Skip budget check for expired budgets + } + } + } + + // Check if current usage exceeds budget limit + if budget.CurrentUsage > budget.MaxLimit { + return fmt.Errorf("%s budget exceeded: %.4f > %.4f dollars", + budgetNames[i], budget.CurrentUsage, budget.MaxLimit) + } + } + + return nil +} + +// UpdateBudget performs atomic budget updates across the hierarchy (both in memory and in database) +func (gs *GovernanceStore) UpdateBudget(vk *configstore.TableVirtualKey, cost float64) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") + } + + // Collect budget IDs using fast in-memory lookup instead of DB queries + budgetIDs := gs.collectBudgetIDsFromMemory(vk) + + if gs.configStore == nil { + for _, budgetID := range budgetIDs { + // Update in-memory cache for next read (lock-free) + if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { + if cachedBudget, ok := cachedBudgetValue.(*configstore.TableBudget); ok && cachedBudget != nil { + clone := *cachedBudget + clone.CurrentUsage += cost + gs.budgets.Store(budgetID, &clone) + } + } + } + + return nil + } + + return gs.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // budgetIDs already collected from in-memory data - no need to duplicate + + // Update each budget atomically + for _, budgetID := range budgetIDs { + var budget configstore.TableBudget + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&budget, "id = ?", budgetID).Error; err != nil { + return fmt.Errorf("failed to lock budget %s: %w", budgetID, err) + } + + // Check if budget needs reset + if err := gs.resetBudgetIfNeeded(tx, &budget); err != nil { + return fmt.Errorf("failed to reset budget: %w", err) + } + + // Update usage + budget.CurrentUsage += cost + if err := gs.configStore.UpdateBudget(&budget, tx); err != nil { + return fmt.Errorf("failed to save budget %s: %w", budgetID, err) + } + + // Update in-memory cache for next read (lock-free) + if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { + if cachedBudget, ok := cachedBudgetValue.(*configstore.TableBudget); ok && cachedBudget != nil { + clone := *cachedBudget + clone.CurrentUsage += cost + clone.LastReset = budget.LastReset + gs.budgets.Store(budgetID, &clone) + } + } + } + + return nil + }) +} + +// UpdateRateLimitUsage updates rate limit counters (lock-free) +func (gs *GovernanceStore) UpdateRateLimitUsage(vkValue string, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { + if vkValue == "" { + return fmt.Errorf("virtual key value cannot be empty") + } + + vkValue_, exists := gs.virtualKeys.Load(vkValue) + if !exists || vkValue_ == nil { + return fmt.Errorf("virtual key not found: %s", vkValue) + } + + vk, ok := vkValue_.(*configstore.TableVirtualKey) + if !ok || vk == nil { + return fmt.Errorf("invalid virtual key type for: %s", vkValue) + } + if vk.RateLimit == nil { + return nil // No rate limit configured, nothing to update + } + + rateLimit := vk.RateLimit + now := time.Now() + updated := false + + // Check and reset token counter if needed + if rateLimit.TokenResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if now.Sub(rateLimit.TokenLastReset) >= duration { + rateLimit.TokenCurrentUsage = 0 + rateLimit.TokenLastReset = now + updated = true + } + } + } + + // Check and reset request counter if needed + if rateLimit.RequestResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if now.Sub(rateLimit.RequestLastReset) >= duration { + rateLimit.RequestCurrentUsage = 0 + rateLimit.RequestLastReset = now + updated = true + } + } + } + + // Update usage counters based on flags + if shouldUpdateTokens && tokensUsed > 0 { + rateLimit.TokenCurrentUsage += tokensUsed + updated = true + } + + if shouldUpdateRequests { + rateLimit.RequestCurrentUsage += 1 + updated = true + } + + // Save to database only if something changed + if updated && gs.configStore != nil { + if err := gs.configStore.UpdateRateLimit(rateLimit); err != nil { + return fmt.Errorf("failed to update rate limit usage: %w", err) + } + } + + return nil +} + +// checkAndResetSingleRateLimit checks and resets a single rate limit's counters if expired +func (gs *GovernanceStore) checkAndResetSingleRateLimit(rateLimit *configstore.TableRateLimit, now time.Time) bool { + updated := false + + // Check and reset token counter if needed + if rateLimit.TokenResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if now.Sub(rateLimit.TokenLastReset).Round(time.Millisecond) >= duration { + rateLimit.TokenCurrentUsage = 0 + rateLimit.TokenLastReset = now + updated = true + } + } + } + + // Check and reset request counter if needed + if rateLimit.RequestResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if now.Sub(rateLimit.RequestLastReset).Round(time.Millisecond) >= duration { + rateLimit.RequestCurrentUsage = 0 + rateLimit.RequestLastReset = now + updated = true + } + } + } + + return updated +} + +// ResetExpiredRateLimits performs background reset of expired rate limits (lock-free) +func (gs *GovernanceStore) ResetExpiredRateLimits() error { + now := time.Now() + var resetRateLimits []*configstore.TableRateLimit + + gs.virtualKeys.Range(func(key, value interface{}) bool { + // Type-safe conversion + vk, ok := value.(*configstore.TableVirtualKey) + if !ok || vk == nil || vk.RateLimit == nil { + return true // continue + } + + rateLimit := vk.RateLimit + + // Use helper method to check and reset rate limit + if gs.checkAndResetSingleRateLimit(rateLimit, now) { + resetRateLimits = append(resetRateLimits, rateLimit) + } + return true // continue + }) + + // Persist reset rate limits to database + if len(resetRateLimits) > 0 && gs.configStore != nil { + if err := gs.configStore.UpdateRateLimits(resetRateLimits); err != nil { + return fmt.Errorf("failed to persist rate limit resets to database: %w", err) + } + } + + return nil +} + +// ResetExpiredBudgets checks and resets budgets that have exceeded their reset duration (lock-free) +func (gs *GovernanceStore) ResetExpiredBudgets() error { + now := time.Now() + var resetBudgets []*configstore.TableBudget + + gs.budgets.Range(func(key, value interface{}) bool { + // Type-safe conversion + budget, ok := value.(*configstore.TableBudget) + if !ok || budget == nil { + return true // continue + } + + duration, err := configstore.ParseDuration(budget.ResetDuration) + if err != nil { + gs.logger.Error("invalid budget reset duration %s: %w", budget.ResetDuration, err) + return true // continue + } + + if now.Sub(budget.LastReset) >= duration { + oldUsage := budget.CurrentUsage + budget.CurrentUsage = 0 + budget.LastReset = now + resetBudgets = append(resetBudgets, budget) + + gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", + budget.ID, oldUsage)) + } + return true // continue + }) + + // Persist to database if any resets occurred + if len(resetBudgets) > 0 && gs.configStore != nil { + if err := gs.configStore.UpdateBudgets(resetBudgets); err != nil { + return fmt.Errorf("failed to persist budget resets to database: %w", err) + } + } + + return nil +} + +// DATABASE METHODS + +// loadFromDatabase loads all governance data from the database into memory +func (gs *GovernanceStore) loadFromDatabase() error { + // Load customers with their budgets + customers, err := gs.configStore.GetCustomers() + if err != nil { + return fmt.Errorf("failed to load customers: %w", err) + } + + // Load teams with their budgets + teams, err := gs.configStore.GetTeams("") + if err != nil { + return fmt.Errorf("failed to load teams: %w", err) + } + + // Load virtual keys with all relationships + virtualKeys, err := gs.configStore.GetVirtualKeys() + if err != nil { + return fmt.Errorf("failed to load virtual keys: %w", err) + } + + // Load budgets + budgets, err := gs.configStore.GetBudgets() + if err != nil { + return fmt.Errorf("failed to load budgets: %w", err) + } + + // Rebuild in-memory structures (lock-free) + gs.rebuildInMemoryStructures(customers, teams, virtualKeys, budgets) + + return nil +} + +// loadFromConfigMemory loads all governance data from the config's memory into store's memory +func (gs *GovernanceStore) loadFromConfigMemory(config *configstore.GovernanceConfig) error { + if config == nil { + return fmt.Errorf("governance config is nil") + } + + // Load customers with their budgets + customers := config.Customers + + // Load teams with their budgets + teams := config.Teams + + // Load budgets + budgets := config.Budgets + + // Load virtual keys with all relationships + virtualKeys := config.VirtualKeys + + // Load rate limits + rateLimits := config.RateLimits + + // Populate virtual keys with their relationships + for i := range virtualKeys { + vk := &virtualKeys[i] + + for i := range teams { + if vk.TeamID != nil && teams[i].ID == *vk.TeamID { + vk.Team = &teams[i] + } + } + + for i := range customers { + if vk.CustomerID != nil && customers[i].ID == *vk.CustomerID { + vk.Customer = &customers[i] + } + } + + for i := range budgets { + if vk.BudgetID != nil && budgets[i].ID == *vk.BudgetID { + vk.Budget = &budgets[i] + } + } + + for i := range rateLimits { + if vk.RateLimitID != nil && rateLimits[i].ID == *vk.RateLimitID { + vk.RateLimit = &rateLimits[i] + } + } + + virtualKeys[i] = *vk + } + + // Rebuild in-memory structures (lock-free) + gs.rebuildInMemoryStructures(customers, teams, virtualKeys, budgets) + + return nil +} + +// rebuildInMemoryStructures rebuilds all in-memory data structures (lock-free) +func (gs *GovernanceStore) rebuildInMemoryStructures(customers []configstore.TableCustomer, teams []configstore.TableTeam, virtualKeys []configstore.TableVirtualKey, budgets []configstore.TableBudget) { + // Clear existing data by creating new sync.Maps + gs.virtualKeys = sync.Map{} + gs.teams = sync.Map{} + gs.customers = sync.Map{} + gs.budgets = sync.Map{} + + // Build customers map + for i := range customers { + customer := &customers[i] + gs.customers.Store(customer.ID, customer) + } + + // Build teams map + for i := range teams { + team := &teams[i] + gs.teams.Store(team.ID, team) + } + + // Build budgets map + for i := range budgets { + budget := &budgets[i] + gs.budgets.Store(budget.ID, budget) + } + + // Build virtual keys map and track active VKs + for i := range virtualKeys { + vk := &virtualKeys[i] + gs.virtualKeys.Store(vk.Value, vk) + } +} + +// UTILITY FUNCTIONS + +// collectBudgetsFromHierarchy collects budgets and their metadata from the hierarchy (VK β†’ Team β†’ Customer) +func (gs *GovernanceStore) collectBudgetsFromHierarchy(vk *configstore.TableVirtualKey) ([]*configstore.TableBudget, []string) { + if vk == nil { + return nil, nil + } + + var budgets []*configstore.TableBudget + var budgetNames []string + + // Collect all budgets in hierarchy order using lock-free sync.Map access (VK β†’ Team β†’ Customer) + if vk.BudgetID != nil { + if budgetValue, exists := gs.budgets.Load(*vk.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstore.TableBudget); ok && budget != nil { + budgets = append(budgets, budget) + budgetNames = append(budgetNames, "VK") + } + } + } + + if vk.TeamID != nil { + if teamValue, exists := gs.teams.Load(*vk.TeamID); exists && teamValue != nil { + if team, ok := teamValue.(*configstore.TableTeam); ok && team != nil { + if team.BudgetID != nil { + if budgetValue, exists := gs.budgets.Load(*team.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstore.TableBudget); ok && budget != nil { + budgets = append(budgets, budget) + budgetNames = append(budgetNames, "Team") + } + } + } + + // Check if team belongs to a customer + if team.CustomerID != nil { + if customerValue, exists := gs.customers.Load(*team.CustomerID); exists && customerValue != nil { + if customer, ok := customerValue.(*configstore.TableCustomer); ok && customer != nil { + if customer.BudgetID != nil { + if budgetValue, exists := gs.budgets.Load(*customer.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstore.TableBudget); ok && budget != nil { + budgets = append(budgets, budget) + budgetNames = append(budgetNames, "Customer") + } + } + } + } + } + } + } + } + } + + if vk.CustomerID != nil { + if customerValue, exists := gs.customers.Load(*vk.CustomerID); exists && customerValue != nil { + if customer, ok := customerValue.(*configstore.TableCustomer); ok && customer != nil { + if customer.BudgetID != nil { + if budgetValue, exists := gs.budgets.Load(*customer.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstore.TableBudget); ok && budget != nil { + budgets = append(budgets, budget) + budgetNames = append(budgetNames, "Customer") + } + } + } + } + } + } + + return budgets, budgetNames +} + +// collectBudgetIDsFromMemory collects budget IDs from in-memory store data (lock-free) +func (gs *GovernanceStore) collectBudgetIDsFromMemory(vk *configstore.TableVirtualKey) []string { + budgets, _ := gs.collectBudgetsFromHierarchy(vk) + + budgetIDs := make([]string, len(budgets)) + for i, budget := range budgets { + budgetIDs[i] = budget.ID + } + + return budgetIDs +} + +// resetBudgetIfNeeded checks and resets budget within a transaction +func (gs *GovernanceStore) resetBudgetIfNeeded(tx *gorm.DB, budget *configstore.TableBudget) error { + duration, err := configstore.ParseDuration(budget.ResetDuration) + if err != nil { + return fmt.Errorf("invalid reset duration %s: %w", budget.ResetDuration, err) + } + + now := time.Now() + if now.Sub(budget.LastReset) >= duration { + budget.CurrentUsage = 0 + budget.LastReset = now + + if gs.configStore != nil { + // Save reset to database + if err := gs.configStore.UpdateBudget(budget, tx); err != nil { + return fmt.Errorf("failed to save budget reset: %w", err) + } + } + } + + return nil +} + +// PUBLIC API METHODS + +// CreateVirtualKeyInMemory adds a new virtual key to the in-memory store (lock-free) +func (gs *GovernanceStore) CreateVirtualKeyInMemory(vk *configstore.TableVirtualKey) { // with rateLimit preloaded + if vk == nil { + return // Nothing to create + } + gs.virtualKeys.Store(vk.Value, vk) +} + +// UpdateVirtualKeyInMemory updates an existing virtual key in the in-memory store (lock-free) +func (gs *GovernanceStore) UpdateVirtualKeyInMemory(vk *configstore.TableVirtualKey) { // with rateLimit preloaded + if vk == nil { + return // Nothing to update + } + gs.virtualKeys.Store(vk.Value, vk) +} + +// DeleteVirtualKeyInMemory removes a virtual key from the in-memory store +func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { + if vkID == "" { + return // Nothing to delete + } + + // Find and delete the VK by ID (lock-free) + gs.virtualKeys.Range(func(key, value interface{}) bool { + // Type-safe conversion + vk, ok := value.(*configstore.TableVirtualKey) + if !ok || vk == nil { + return true // continue iteration + } + + if vk.ID == vkID { + gs.virtualKeys.Delete(key) + return false // stop iteration + } + return true // continue iteration + }) +} + +// CreateTeamInMemory adds a new team to the in-memory store (lock-free) +func (gs *GovernanceStore) CreateTeamInMemory(team *configstore.TableTeam) { + if team == nil { + return // Nothing to create + } + gs.teams.Store(team.ID, team) +} + +// UpdateTeamInMemory updates an existing team in the in-memory store (lock-free) +func (gs *GovernanceStore) UpdateTeamInMemory(team *configstore.TableTeam) { + if team == nil { + return // Nothing to update + } + gs.teams.Store(team.ID, team) +} + +// DeleteTeamInMemory removes a team from the in-memory store (lock-free) +func (gs *GovernanceStore) DeleteTeamInMemory(teamID string) { + if teamID == "" { + return // Nothing to delete + } + gs.teams.Delete(teamID) +} + +// CreateCustomerInMemory adds a new customer to the in-memory store (lock-free) +func (gs *GovernanceStore) CreateCustomerInMemory(customer *configstore.TableCustomer) { + if customer == nil { + return // Nothing to create + } + gs.customers.Store(customer.ID, customer) +} + +// UpdateCustomerInMemory updates an existing customer in the in-memory store (lock-free) +func (gs *GovernanceStore) UpdateCustomerInMemory(customer *configstore.TableCustomer) { + if customer == nil { + return // Nothing to update + } + gs.customers.Store(customer.ID, customer) +} + +// DeleteCustomerInMemory removes a customer from the in-memory store (lock-free) +func (gs *GovernanceStore) DeleteCustomerInMemory(customerID string) { + if customerID == "" { + return // Nothing to delete + } + gs.customers.Delete(customerID) +} + +// CreateBudgetInMemory adds a new budget to the in-memory store (lock-free) +func (gs *GovernanceStore) CreateBudgetInMemory(budget *configstore.TableBudget) { + if budget == nil { + return // Nothing to create + } + gs.budgets.Store(budget.ID, budget) +} + +// UpdateBudgetInMemory updates a specific budget in the in-memory cache (lock-free) +func (gs *GovernanceStore) UpdateBudgetInMemory(budget *configstore.TableBudget) error { + if budget == nil { + return fmt.Errorf("budget cannot be nil") + } + gs.budgets.Store(budget.ID, budget) + return nil +} + +// DeleteBudgetInMemory removes a budget from the in-memory store (lock-free) +func (gs *GovernanceStore) DeleteBudgetInMemory(budgetID string) { + if budgetID == "" { + return // Nothing to delete + } + gs.budgets.Delete(budgetID) +} diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go new file mode 100644 index 000000000..bf5b4fdc5 --- /dev/null +++ b/plugins/governance/tracker.go @@ -0,0 +1,245 @@ +// Package governance provides simplified usage tracking for the new hierarchical system +package governance + +import ( + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" +) + +// UsageUpdate contains data for VK-level usage tracking +type UsageUpdate struct { + VirtualKey string `json:"virtual_key"` + Provider schemas.ModelProvider `json:"provider"` + Model string `json:"model"` + Success bool `json:"success"` + TokensUsed int64 `json:"tokens_used"` + Cost float64 `json:"cost"` // Cost in dollars + RequestID string `json:"request_id"` + TeamID *string `json:"team_id,omitempty"` // For audit trail + CustomerID *string `json:"customer_id,omitempty"` // For audit trail + + // Streaming optimization fields + IsStreaming bool `json:"is_streaming"` // Whether this is a streaming response + IsFinalChunk bool `json:"is_final_chunk"` // Whether this is the final chunk + HasUsageData bool `json:"has_usage_data"` // Whether this chunk contains usage data +} + +// UsageTracker manages VK-level usage tracking and budget management +type UsageTracker struct { + store *GovernanceStore + resolver *BudgetResolver + configStore configstore.ConfigStore + logger schemas.Logger + + // Background workers + resetTicker *time.Ticker + done chan struct{} + wg sync.WaitGroup +} + +// NewUsageTracker creates a new usage tracker for the hierarchical budget system +func NewUsageTracker(store *GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker { + tracker := &UsageTracker{ + store: store, + resolver: resolver, + configStore: configStore, + logger: logger, + done: make(chan struct{}), + } + + // Start background workers for business logic + tracker.startWorkers() + + tracker.logger.Info("usage tracker initialized for hierarchical budget system") + return tracker +} + +// UpdateUsage queues a usage update for async processing (main business entry point) +func (t *UsageTracker) UpdateUsage(update *UsageUpdate) { + // Get virtual key + vk, exists := t.store.GetVirtualKey(update.VirtualKey) + if !exists { + t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey)) + return + } + + // Only process successful requests for usage tracking + if !update.Success { + t.logger.Debug(fmt.Sprintf("Request was not successful, skipping usage update for VK: %s", vk.ID)) + return + } + + // Streaming optimization: only process certain updates based on streaming status + shouldUpdateTokens := !update.IsStreaming || (update.IsStreaming && update.HasUsageData) + shouldUpdateRequests := !update.IsStreaming || (update.IsStreaming && update.IsFinalChunk) + shouldUpdateBudget := !update.IsStreaming || (update.IsStreaming && update.HasUsageData) + + // Update VK rate limit usage if applicable + if vk.RateLimit != nil { + if err := t.store.UpdateRateLimitUsage(update.VirtualKey, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + t.logger.Error("failed to update rate limit usage for VK %s: %v", vk.ID, err) + } + } + + // Update budget usage in hierarchy (VK β†’ Team β†’ Customer) only if we have usage data + if shouldUpdateBudget && update.Cost > 0 { + t.updateBudgetHierarchy(vk, update) + } +} + +// updateBudgetHierarchy updates budget usage atomically in the VK β†’ Team β†’ Customer hierarchy +func (t *UsageTracker) updateBudgetHierarchy(vk *configstore.TableVirtualKey, update *UsageUpdate) { + // Use atomic budget update to prevent race conditions and ensure consistency + if err := t.store.UpdateBudget(vk, update.Cost); err != nil { + t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) + } +} + +// startWorkers starts all background workers for business logic +func (t *UsageTracker) startWorkers() { + // Counter reset manager (business logic) + t.resetTicker = time.NewTicker(1 * time.Minute) + t.wg.Add(1) + go t.resetWorker() +} + +// resetWorker manages periodic resets of rate limit and usage counters +func (t *UsageTracker) resetWorker() { + defer t.wg.Done() + + for { + select { + case <-t.resetTicker.C: + t.resetExpiredCounters() + + case <-t.done: + return + } + } +} + +// resetExpiredCounters manages periodic resets of usage counters AND budgets using flexible durations +func (t *UsageTracker) resetExpiredCounters() { + // ==== PART 1: Reset Rate Limits ==== + if err := t.store.ResetExpiredRateLimits(); err != nil { + t.logger.Error("failed to reset expired rate limits: %v", err) + } + + // ==== PART 2: Reset Budgets ==== + if err := t.store.ResetExpiredBudgets(); err != nil { + t.logger.Error("failed to reset expired budgets: %v", err) + } +} + +// Public methods for monitoring and admin operations + +// PerformStartupResets checks and resets any expired rate limits and budgets on startup +func (t *UsageTracker) PerformStartupResets() error { + if t.configStore == nil { + t.logger.Warn("config store is not available, skipping initialization of usage tracker") + return nil + } + + t.logger.Info("performing startup reset check for expired rate limits and budgets") + now := time.Now() + + var resetRateLimits []*configstore.TableRateLimit + var errs []string + var vksWithRateLimits int + var vksWithoutRateLimits int + + // ==== RESET EXPIRED RATE LIMITS ==== + // Check ALL virtual keys (both active and inactive) for expired rate limits + allVKs, err := t.configStore.GetVirtualKeys() + if err != nil { + errs = append(errs, fmt.Sprintf("failed to load virtual keys for reset: %s", err.Error())) + } else { + t.logger.Debug(fmt.Sprintf("startup reset: checking %d virtual keys (active + inactive) for expired rate limits", len(allVKs))) + } + + for i := range allVKs { + vk := &allVKs[i] // Get pointer to VK for modifications + if vk.RateLimit == nil { + vksWithoutRateLimits++ + continue + } + + vksWithRateLimits++ + + rateLimit := vk.RateLimit + rateLimitUpdated := false + + // Check token limits + if rateLimit.TokenResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + timeSinceReset := now.Sub(rateLimit.TokenLastReset) + if timeSinceReset >= duration { + rateLimit.TokenCurrentUsage = 0 + rateLimit.TokenLastReset = now + rateLimitUpdated = true + } + } else { + errs = append(errs, fmt.Sprintf("invalid token reset duration for VK %s: %s", vk.ID, *rateLimit.TokenResetDuration)) + } + } + + // Check request limits + if rateLimit.RequestResetDuration != nil { + if duration, err := configstore.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + timeSinceReset := now.Sub(rateLimit.RequestLastReset) + if timeSinceReset >= duration { + rateLimit.RequestCurrentUsage = 0 + rateLimit.RequestLastReset = now + rateLimitUpdated = true + } + } else { + errs = append(errs, fmt.Sprintf("invalid request reset duration for VK %s: %s", vk.ID, *rateLimit.RequestResetDuration)) + } + } + + if rateLimitUpdated { + resetRateLimits = append(resetRateLimits, rateLimit) + } + } + + // DB reset is also handled by this function + if err := t.store.ResetExpiredBudgets(); err != nil { + errs = append(errs, fmt.Sprintf("failed to reset expired budgets: %s", err.Error())) + } + + // ==== PERSIST RESETS TO DATABASE ==== + if t.configStore != nil { + if len(resetRateLimits) > 0 { + if err := t.configStore.UpdateRateLimits(resetRateLimits); err != nil { + errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error())) + } + } + } + t.logger.Info("startup reset summary: VKs with RL=%d, without RL=%d, RL resets=%d", vksWithRateLimits, vksWithoutRateLimits, len(resetRateLimits)) + if len(errs) > 0 { + t.logger.Error("startup reset encountered %d errors: %v", len(errs), errs) + return fmt.Errorf("startup reset completed with %d errors", len(errs)) + } + + return nil +} + +// Cleanup stops all background workers and flushes pending operations +func (t *UsageTracker) Cleanup() error { + // Stop background workers + close(t.done) + + if t.resetTicker != nil { + t.resetTicker.Stop() + } + + // Wait for workers to finish + t.wg.Wait() + + t.logger.Debug("usage tracker cleanup completed") + return nil +} diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go new file mode 100644 index 000000000..ba5e6cba2 --- /dev/null +++ b/plugins/governance/utils.go @@ -0,0 +1,62 @@ +// Package governance provides utility functions for the governance plugin +package governance + +import ( + "context" + + "github.com/maximhq/bifrost/core/schemas" +) + +type ContextKey string + +// extractHeadersFromContext extracts governance headers from context (standalone version) +func extractHeadersFromContext(ctx context.Context) map[string]string { + headers := make(map[string]string) + + // Extract governance headers using lib.ContextKey + if teamID := getStringFromContext(ctx, ContextKey("x-bf-team")); teamID != "" { + headers["x-bf-team"] = teamID + } + if userID := getStringFromContext(ctx, ContextKey("x-bf-user")); userID != "" { + headers["x-bf-user"] = userID + } + if customerID := getStringFromContext(ctx, ContextKey("x-bf-customer")); customerID != "" { + headers["x-bf-customer"] = customerID + } + + return headers +} + +// getStringFromContext safely extracts a string value from context +func getStringFromContext(ctx context.Context, key any) string { + if value := ctx.Value(key); value != nil { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +// hasUsageData checks if the response contains actual usage information +func hasUsageData(result *schemas.BifrostResponse) bool { + if result == nil { + return false + } + + // Check main usage field + if result.Usage != nil { + return true + } + + // Check speech usage + if result.Speech != nil && result.Speech.Usage != nil { + return true + } + + // Check transcribe usage + if result.Transcribe != nil && result.Transcribe.Usage != nil { + return true + } + + return false +} diff --git a/plugins/governance/version b/plugins/governance/version new file mode 100644 index 000000000..f69752ab1 --- /dev/null +++ b/plugins/governance/version @@ -0,0 +1 @@ +1.2.16 diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/jsonparser/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/jsonparser/go.mod b/plugins/jsonparser/go.mod new file mode 100644 index 000000000..6f3411b71 --- /dev/null +++ b/plugins/jsonparser/go.mod @@ -0,0 +1,51 @@ +module github.com/maximhq/bifrost/plugins/jsonparser + +go 1.24 + +toolchain go1.24.3 + +require github.com/maximhq/bifrost/core v1.1.37 + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/jsonparser/go.sum b/plugins/jsonparser/go.sum new file mode 100644 index 000000000..7bf3f8a9f --- /dev/null +++ b/plugins/jsonparser/go.sum @@ -0,0 +1,125 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go new file mode 100644 index 000000000..10289a6a0 --- /dev/null +++ b/plugins/jsonparser/main.go @@ -0,0 +1,417 @@ +package jsonparser + +import ( + "context" + "encoding/json" + "strings" + "sync" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + PluginName = "streaming-json-parser" + EnableStreamingJSONParser = "enable-streaming-json-parser" +) + +type Usage string + +const ( + AllRequests Usage = "all_requests" + PerRequest Usage = "per_request" +) + +// AccumulatedContent holds both the content and timestamp for a request +type AccumulatedContent struct { + Content *strings.Builder + Timestamp time.Time +} + +// JsonParserPlugin provides JSON parsing capabilities for streaming responses +// It handles partial JSON chunks by accumulating them and making the accumulated content valid JSON +type JsonParserPlugin struct { + usage Usage + // State management for accumulating chunks + accumulatedContent map[string]*AccumulatedContent // requestID -> accumulated content with timestamp + mutex sync.RWMutex + // Cleanup configuration + cleanupInterval time.Duration + maxAge time.Duration + stopCleanup chan struct{} + stopOnce sync.Once +} + +// PluginConfig holds configuration options for the JSON parser plugin +type PluginConfig struct { + Usage Usage + CleanupInterval time.Duration + MaxAge time.Duration +} + +// Init creates a new JSON parser plugin instance with custom configuration +func Init(config PluginConfig) (*JsonParserPlugin, error) { + // Set defaults if not provided + if config.CleanupInterval <= 0 { + config.CleanupInterval = 5 * time.Minute + } + if config.MaxAge <= 0 { + config.MaxAge = 30 * time.Minute + } + if config.Usage == "" { + config.Usage = PerRequest + } + + plugin := &JsonParserPlugin{ + usage: config.Usage, + accumulatedContent: make(map[string]*AccumulatedContent), + cleanupInterval: config.CleanupInterval, + maxAge: config.MaxAge, + stopCleanup: make(chan struct{}), + } + + // Start the cleanup goroutine + go plugin.startCleanupGoroutine() + + return plugin, nil +} + +// GetName returns the plugin name +func (p *JsonParserPlugin) GetName() string { + return PluginName +} + +// PreHook is not used for this plugin as we only process responses +func (p *JsonParserPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + return req, nil, nil +} + +// PostHook processes streaming responses by accumulating chunks and making accumulated content valid JSON +func (p *JsonParserPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + // Check if plugin should run based on usage type + if !p.shouldRun(ctx) { + return result, err, nil + } + + // If there's an error, don't process + if err != nil { + return result, err, nil + } + + // If no result, return as is + if result == nil { + return result, err, nil + } + + // Get request ID for state management, if it's not set, return as is + requestID := p.getRequestID(ctx, result) + if requestID == "" { + return result, err, nil + } + + // Process only streaming choices to accumulate and fix partial JSON + if len(result.Choices) > 0 { + for i := range result.Choices { + choice := &result.Choices[i] + + // Handle only streaming response + if choice.BifrostStreamResponseChoice != nil { + if choice.BifrostStreamResponseChoice.Delta.Content != nil { + content := *choice.BifrostStreamResponseChoice.Delta.Content + if content != "" { + // Accumulate the content + accumulated := p.accumulateContent(requestID, content) + + // Process the accumulated content to make it valid JSON + fixedContent := p.parsePartialJSON(accumulated) + + if !p.isValidJSON(fixedContent) { + err = &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: "Invalid JSON in streaming response", + }, + StreamControl: &schemas.StreamControl{ + SkipStream: bifrost.Ptr(true), + }, + } + + return nil, err, nil + } + + // Replace the delta content with the complete valid JSON + choice.BifrostStreamResponseChoice.Delta.Content = &fixedContent + } + } + } + } + } + + // If this is the final chunk, cleanup the accumulated content for this request + if streamEndIndicatorValue := (*ctx).Value(schemas.BifrostContextKeyStreamEndIndicator); streamEndIndicatorValue != nil { + isFinalChunk, ok := streamEndIndicatorValue.(bool) + if ok && isFinalChunk { + p.ClearRequestState(requestID) + } + } + + return result, err, nil +} + +// getRequestID extracts a unique identifier for the request to maintain state +func (p *JsonParserPlugin) getRequestID(ctx *context.Context, result *schemas.BifrostResponse) string { + + // Try to get from result + if result != nil && result.ID != "" { + return result.ID + } + + // Try to get from context if not available in result + if ctx != nil { + if requestID, ok := (*ctx).Value(schemas.BifrostContextKey("request-id")).(string); ok && requestID != "" { + return requestID + } + } + + return "" +} + +// accumulateContent adds new content to the accumulated content for a specific request +func (p *JsonParserPlugin) accumulateContent(requestID, newContent string) string { + p.mutex.Lock() + defer p.mutex.Unlock() + + // Get existing accumulated content + existing := p.accumulatedContent[requestID] + + if existing != nil { + // Append to existing builder + existing.Content.WriteString(newContent) + return existing.Content.String() + } else { + // Create new builder + builder := &strings.Builder{} + builder.WriteString(newContent) + p.accumulatedContent[requestID] = &AccumulatedContent{ + Content: builder, + Timestamp: time.Now(), + } + return builder.String() + } +} + +// shouldRun determines if the plugin should process the request based on usage type +func (p *JsonParserPlugin) shouldRun(ctx *context.Context) bool { + // Run only for chat completion stream requests + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok || requestType != schemas.ChatCompletionStreamRequest { + return false + } + + switch p.usage { + case AllRequests: + return true + case PerRequest: + // Check if the context contains the plugin-specific key + if ctx != nil { + if value, ok := (*ctx).Value(EnableStreamingJSONParser).(bool); ok { + return value + } + } + return false + default: + return false + } +} + +// Cleanup performs plugin cleanup and clears accumulated content +func (p *JsonParserPlugin) Cleanup() error { + // Stop the cleanup goroutine + p.StopCleanup() + + p.mutex.Lock() + defer p.mutex.Unlock() + + // Clear accumulated content + p.accumulatedContent = make(map[string]*AccumulatedContent) + return nil +} + +// ClearRequestState clears the accumulated content for a specific request +func (p *JsonParserPlugin) ClearRequestState(requestID string) { + p.mutex.Lock() + defer p.mutex.Unlock() + + delete(p.accumulatedContent, requestID) +} + +// parsePartialJSON parses a JSON string that may be missing closing braces +func (p *JsonParserPlugin) parsePartialJSON(s string) string { + // Trim whitespace + s = strings.TrimSpace(s) + if s == "" { + return "{}" + } + + // Quick check: if it starts with { or [, it might be JSON + if s[0] != '{' && s[0] != '[' { + return s + } + + // First, try to parse the string as-is (fast path) + if p.isValidJSON(s) { + return s + } + + // Use a more efficient approach: build the completion directly + return p.completeJSON(s) +} + +// isValidJSON checks if a string is valid JSON +func (p *JsonParserPlugin) isValidJSON(s string) bool { + // Trim whitespace + s = strings.TrimSpace(s) + + // Empty string after trimming is not valid JSON + if s == "" { + return false + } + + return json.Valid([]byte(s)) +} + +// completeJSON completes partial JSON with O(n) time complexity +func (p *JsonParserPlugin) completeJSON(s string) string { + // Pre-allocate buffer with estimated capacity + capacity := len(s) + 10 // Estimate max 10 closing characters needed + result := make([]byte, 0, capacity) + + var stack []byte + inString := false + escaped := false + + // Process the string once + for i := 0; i < len(s); i++ { + char := s[i] + result = append(result, char) + + if escaped { + escaped = false + continue + } + + if char == '\\' { + escaped = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + switch char { + case '{', '[': + if char == '{' { + stack = append(stack, '}') + } else { + stack = append(stack, ']') + } + case '}', ']': + if len(stack) > 0 && stack[len(stack)-1] == char { + stack = stack[:len(stack)-1] + } + } + } + + // Close any unclosed strings + if inString { + if escaped { + // Remove the trailing backslash + if len(result) > 0 { + result = result[:len(result)-1] + } + } + result = append(result, '"') + } + + // Add closing characters in reverse order + for i := len(stack) - 1; i >= 0; i-- { + result = append(result, stack[i]) + } + + // Validate the result + if p.isValidJSON(string(result)) { + return string(result) + } + + // If still invalid, try progressive truncation (but more efficiently) + return p.progressiveTruncation(s, result) +} + +// progressiveTruncation efficiently tries different truncation points +func (p *JsonParserPlugin) progressiveTruncation(original string, completed []byte) string { + // Try removing characters from the end until we get valid JSON + // Use binary search for better performance + left, right := 0, len(completed) + + for left < right { + mid := (left + right) / 2 + candidate := completed[:mid] + + if p.isValidJSON(string(candidate)) { + left = mid + 1 + } else { + right = mid + } + } + + // Try the best candidate + if left > 0 && p.isValidJSON(string(completed[:left-1])) { + return string(completed[:left-1]) + } + + // Fallback to original + return original +} + +// startCleanupGoroutine starts a goroutine that periodically cleans up old accumulated content +func (p *JsonParserPlugin) startCleanupGoroutine() { + ticker := time.NewTicker(p.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanupOldEntries() + case <-p.stopCleanup: + return + } + } +} + +// cleanupOldEntries removes accumulated content entries that are older than maxAge +func (p *JsonParserPlugin) cleanupOldEntries() { + p.mutex.Lock() + defer p.mutex.Unlock() + + now := time.Now() + cutoff := now.Add(-p.maxAge) + + for requestID, content := range p.accumulatedContent { + if content.Timestamp.Before(cutoff) { + delete(p.accumulatedContent, requestID) + } + } +} + +// StopCleanup stops the cleanup goroutine +func (p *JsonParserPlugin) StopCleanup() { + p.stopOnce.Do(func() { + close(p.stopCleanup) + }) +} diff --git a/plugins/jsonparser/plugin_test.go b/plugins/jsonparser/plugin_test.go new file mode 100644 index 000000000..a4b3c9a49 --- /dev/null +++ b/plugins/jsonparser/plugin_test.go @@ -0,0 +1,328 @@ +package jsonparser + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the JSON parser plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +// Currently only supports OpenAI for simplicity in testing. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +// GetKeysForProvider returns a mock API key configuration for testing. +// Uses the OPENAI_API_KEY environment variable for authentication. +func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +// Uses standard network and concurrency settings. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestJsonParserPluginEndToEnd tests the integration of the JSON parser plugin with Bifrost. +// It performs the following steps: +// 1. Initializes the JSON parser plugin with AllRequests usage +// 2. Sets up a test Bifrost instance with the plugin +// 3. Makes a test chat completion request with streaming enabled +// 4. Verifies that the plugin processes the streaming response correctly +// +// Required environment variables: +// - OPENAI_API_KEY: Your OpenAI API key for the test request +func TestJsonParserPluginEndToEnd(t *testing.T) { + ctx := context.Background() + // Check if OpenAI API key is set + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping end-to-end test") + } + + // Initialize the JSON parser plugin for all requests + plugin, err := Init(PluginConfig{ + Usage: AllRequests, + CleanupInterval: 5 * time.Minute, + MaxAge: 30 * time.Minute, + }) + if err != nil { + t.Fatalf("Error initializing JSON parser plugin: %v", err) + } + + account := BaseAccount{} + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + // Make a test chat completion request with streaming enabled + // Request JSON output to test the parser + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Return a JSON object with name, age, and city fields. Example: {\"name\": \"John\", \"age\": 30, \"city\": \"New York\"}"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + + ExtraParams: map[string]any{ + "stream": true, + "response_format": map[string]any{ + "type": "json_object", + }, + }, + }, + } + // Make the streaming request + responseChan, bifrostErr := client.ChatCompletionStreamRequest(ctx, request) + + if bifrostErr != nil { + t.Fatalf("Error in Bifrost request: %v", bifrostErr) + } + + // Process streaming responses + if responseChan != nil { + t.Logf("Streaming response channel received") + + // Read from the channel to see the streaming responses + responseCount := 0 + + for streamResponse := range responseChan { + responseCount++ + + if streamResponse.BifrostError != nil { + t.Logf("Streaming response error: %v", streamResponse.BifrostError) + } + + if streamResponse.BifrostResponse != nil { + for _, choice := range streamResponse.BifrostResponse.Choices { + if choice.BifrostStreamResponseChoice != nil && choice.BifrostStreamResponseChoice.Delta.Content != nil { + content := *choice.BifrostStreamResponseChoice.Delta.Content + if content != "" { + t.Logf("Chunk %d: %s", responseCount, content) + } + } + } + } + } + + t.Logf("Stream completed after %d responses", responseCount) + } else { + t.Logf("No streaming response channel received") + } + + t.Log("End-to-end test completed - check logs for JSON parsing behavior") +} + +// TestJsonParserPluginPerRequest tests the per-request configuration of the JSON parser plugin. +// It tests how the plugin behaves when enabled via context for specific requests. +// +// Required environment variables: +// - OPENAI_API_KEY: Your OpenAI API key for the test request +func TestJsonParserPluginPerRequest(t *testing.T) { + ctx := context.Background() + // Check if OpenAI API key is set + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping per-request test") + } + + // Initialize the JSON parser plugin for per-request usage + plugin, err := Init(PluginConfig{ + Usage: PerRequest, + CleanupInterval: 5 * time.Minute, + MaxAge: 30 * time.Minute, + }) + if err != nil { + t.Fatalf("Error initializing JSON parser plugin: %v", err) + } + + account := BaseAccount{} + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + // Test request with plugin enabled via context + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Return a JSON object with name and age fields."), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + ExtraParams: map[string]any{ + "stream": true, + "response_format": map[string]any{ + "type": "json_object", + }, + }, + }, + } + + // Create context with plugin enabled + newContext := context.WithValue(ctx, EnableStreamingJSONParser, true) + + // Make the streaming request + responseChan, bifrostErr := client.ChatCompletionStreamRequest(newContext, request) + + if bifrostErr != nil { + t.Logf("Error in Bifrost request: %v", bifrostErr) + } + + // Process streaming responses + if responseChan != nil { + t.Logf("Streaming response channel received for per-request test") + + // Read from the channel to see the streaming responses + responseCount := 0 + + for streamResponse := range responseChan { + responseCount++ + + if streamResponse.BifrostError != nil { + t.Logf("Streaming response error: %v", streamResponse.BifrostError) + } + + if streamResponse.BifrostResponse != nil { + for _, choice := range streamResponse.BifrostResponse.Choices { + if choice.BifrostStreamResponseChoice != nil && choice.BifrostStreamResponseChoice.Delta.Content != nil { + content := *choice.BifrostStreamResponseChoice.Delta.Content + if content != "" { + t.Logf("Per-request chunk %d: %s", responseCount, content) + } + } + } + } + } + + t.Logf("Per-request stream completed after %d responses", responseCount) + } else { + t.Logf("No streaming response channel received for per-request test") + } + + t.Log("Per-request test completed - check logs for JSON parsing behavior") +} + +func TestParsePartialJSON(t *testing.T) { + plugin, err := Init(PluginConfig{ + Usage: AllRequests, + CleanupInterval: 5 * time.Minute, + MaxAge: 30 * time.Minute, + }) + if err != nil { + t.Fatalf("Error initializing JSON parser plugin: %v", err) + } + + tests := []struct { + name string + input string + expected string + }{ + { + name: "Already valid JSON object", + input: `{"name": "John", "age": 30}`, + expected: `{"name": "John", "age": 30}`, + }, + { + name: "Partial JSON object missing closing brace", + input: `{"name": "John", "age": 30, "city": "New York"`, + expected: `{"name": "John", "age": 30, "city": "New York"}`, + }, + { + name: "Partial JSON array missing closing bracket", + input: `["apple", "banana", "cherry"`, + expected: `["apple", "banana", "cherry"]`, + }, + { + name: "Nested partial JSON", + input: `{"user": {"name": "John", "details": {"age": 30, "city": "NY"`, + expected: `{"user": {"name": "John", "details": {"age": 30, "city": "NY"}}}`, + }, + { + name: "Partial JSON with string containing newline", + input: `{"message": "Hello\nWorld"`, + expected: `{"message": "Hello\nWorld"}`, + }, + { + name: "Empty string", + input: "", + expected: "{}", + }, + { + name: "Whitespace only", + input: " \n\t ", + expected: "{}", + }, + { + name: "Non-JSON string", + input: "This is not JSON", + expected: "This is not JSON", + }, + { + name: "Partial JSON with escaped quotes", + input: `{"message": "He said \"Hello\""`, + expected: `{"message": "He said \"Hello\""}`, + }, + { + name: "Complex nested structure", + input: `{"data": {"users": [{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"`, + expected: `{"data": {"users": [{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"}]}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := plugin.parsePartialJSON(tt.input) + if result != tt.expected { + t.Errorf("parsePartialJSON(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version new file mode 100644 index 000000000..05060b805 --- /dev/null +++ b/plugins/jsonparser/version @@ -0,0 +1 @@ +1.2.15 \ No newline at end of file diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md new file mode 100644 index 000000000..06359c9de --- /dev/null +++ b/plugins/logging/changelog.md @@ -0,0 +1,5 @@ + + + +- Upgrades framework to 1.0.23 +- Fixes pricing computation for nested model names. \ No newline at end of file diff --git a/plugins/logging/go.mod b/plugins/logging/go.mod new file mode 100644 index 000000000..c7dbf7248 --- /dev/null +++ b/plugins/logging/go.mod @@ -0,0 +1,88 @@ +module github.com/maximhq/bifrost/plugins/logging + +go 1.24 + +toolchain go1.24.3 + +require ( + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/bifrost/framework v1.0.23 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/redis/go-redis/v9 v9.12.1 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/weaviate/weaviate v1.31.5 // indirect + github.com/weaviate/weaviate-go-client/v5 v5.2.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect + gorm.io/gorm v1.30.1 // indirect +) diff --git a/plugins/logging/go.sum b/plugins/logging/go.sum new file mode 100644 index 000000000..a8bac98bd --- /dev/null +++ b/plugins/logging/go.sum @@ -0,0 +1,355 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/bifrost/framework v1.0.23 h1:erRPP9Q0WIaUgxuLBN8urd77SObEF9irPvpV9Wbegyk= +github.com/maximhq/bifrost/framework v1.0.23/go.mod h1:uEB0iuQtFfuFuMrhccMsb+51mf8m8X2tB8ZlDVoJUbM= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/plugins/logging/main.go b/plugins/logging/main.go new file mode 100644 index 000000000..fbeb3e2b7 --- /dev/null +++ b/plugins/logging/main.go @@ -0,0 +1,678 @@ +// Package logging provides a GORM-based logging plugin for Bifrost. +// This plugin stores comprehensive logs of all requests and responses with search, +// filter, and pagination capabilities. +package logging + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/pricing" +) + +const ( + PluginName = "bifrost-http-logging" +) + +// ContextKey is a custom type for context keys to prevent collisions +type ContextKey string + +// LogOperation represents the type of logging operation +type LogOperation string + +const ( + LogOperationCreate LogOperation = "create" + LogOperationUpdate LogOperation = "update" + LogOperationStreamUpdate LogOperation = "stream_update" +) + +// Context keys for logging optimization +const ( + DroppedCreateContextKey ContextKey = "bifrost-logging-dropped" + CreatedTimestampKey ContextKey = "bifrost-logging-created-timestamp" +) + +// UpdateLogData contains data for log entry updates +type UpdateLogData struct { + Status string + TokenUsage *schemas.LLMUsage + Cost *float64 // Cost in dollars from pricing plugin + OutputMessage *schemas.BifrostMessage + EmbeddingOutput *[]schemas.BifrostEmbedding + ToolCalls *[]schemas.ToolCall + ErrorDetails *schemas.BifrostError + Model string // May be different from request + Object string // May be different from request + SpeechOutput *schemas.BifrostSpeech // For non-streaming speech responses + TranscriptionOutput *schemas.BifrostTranscribe // For non-streaming transcription responses +} + +// StreamUpdateData contains lightweight data for streaming delta updates +type StreamUpdateData struct { + ErrorDetails *schemas.BifrostError + Model string // May be different from request + Object string // May be different from request + TokenUsage *schemas.LLMUsage + Cost *float64 // Cost in dollars from pricing plugin + Delta *schemas.BifrostStreamDelta // The actual streaming delta + FinishReason *string // If the stream is finished + TranscriptionOutput *schemas.BifrostTranscribe // For transcription stream responses +} + +// LogMessage represents a message in the logging queue +type LogMessage struct { + Operation LogOperation + RequestID string + Timestamp time.Time // Of the preHook/postHook call + InitialData *InitialLogData // For create operations + SemanticCacheDebug *schemas.BifrostCacheDebug // For semantic cache operations + UpdateData *UpdateLogData // For update operations + StreamUpdateData *StreamUpdateData // For stream update operations +} + +// InitialLogData contains data for initial log entry creation +type InitialLogData struct { + Provider string + Model string + Object string + InputHistory []schemas.BifrostMessage + Params *schemas.ModelParameters + SpeechInput *schemas.SpeechInput + TranscriptionInput *schemas.TranscriptionInput + Tools *[]schemas.Tool +} + +// LogCallback is a function that gets called when a new log entry is created +type LogCallback func(*logstore.Log) + +// StreamChunk represents a single streaming chunk +type StreamChunk struct { + Timestamp time.Time // When chunk was received + Delta *schemas.BifrostStreamDelta // The actual delta content + FinishReason *string // If this is the final chunk + TokenUsage *schemas.LLMUsage // Token usage if available + SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available + Cost *float64 // Cost in dollars from pricing plugin + ErrorDetails *schemas.BifrostError // Error if any +} + +// StreamAccumulator manages accumulation of streaming chunks +type StreamAccumulator struct { + RequestID string + Chunks []*StreamChunk + IsComplete bool + FinalTimestamp time.Time + Object string // Store object type once for the entire stream + mu sync.Mutex +} + +// LoggerPlugin implements the schemas.Plugin interface +type LoggerPlugin struct { + store logstore.LogStore + pricingManager *pricing.PricingManager + mu sync.Mutex + done chan struct{} + wg sync.WaitGroup + logger schemas.Logger + logCallback LogCallback + droppedRequests atomic.Int64 + cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs + logMsgPool sync.Pool // Pool for reusing LogMessage structs + updateDataPool sync.Pool // Pool for reusing UpdateLogData structs + streamDataPool sync.Pool // Pool for reusing StreamUpdateData structs + streamChunkPool sync.Pool // Pool for reusing StreamChunk structs + streamAccumulators sync.Map // Track accumulators by request ID (atomic) +} + +// retryOnNotFound retries a function up to 3 times with 1-second delays if it returns logstore.ErrNotFound +func retryOnNotFound(ctx context.Context, operation func() error) error { + const maxRetries = 3 + const retryDelay = time.Second + + var lastErr error + for attempt := 0; attempt < maxRetries; attempt++ { + err := operation() + if err == nil { + return nil + } + + // Check if the error is logstore.ErrNotFound + if !errors.Is(err, logstore.ErrNotFound) { + return err + } + + lastErr = err + + // Don't wait after the last attempt + if attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryDelay): + // Continue to next retry + } + } + } + + return lastErr +} + +// Init creates new logger plugin with given log store +func Init(logger schemas.Logger, logsStore logstore.LogStore, pricingManager *pricing.PricingManager) (*LoggerPlugin, error) { + if logsStore == nil { + return nil, fmt.Errorf("logs store cannot be nil") + } + if pricingManager == nil { + logger.Warn("logging plugin requires pricing manager to calculate cost, all cost calculations will be skipped.") + } + + plugin := &LoggerPlugin{ + store: logsStore, + pricingManager: pricingManager, + done: make(chan struct{}), + logger: logger, + logMsgPool: sync.Pool{ + New: func() interface{} { + return &LogMessage{} + }, + }, + updateDataPool: sync.Pool{ + New: func() interface{} { + return &UpdateLogData{} + }, + }, + streamDataPool: sync.Pool{ + New: func() interface{} { + return &StreamUpdateData{} + }, + }, + streamChunkPool: sync.Pool{ + New: func() interface{} { + return &StreamChunk{} + }, + }, + streamAccumulators: sync.Map{}, + } + + // Prewarm the pools for better performance at startup + for range 1000 { + plugin.logMsgPool.Put(&LogMessage{}) + plugin.updateDataPool.Put(&UpdateLogData{}) + plugin.streamDataPool.Put(&StreamUpdateData{}) + plugin.streamChunkPool.Put(&StreamChunk{}) + } + + // Start cleanup ticker (runs every 30 seconds) + plugin.cleanupTicker = time.NewTicker(30 * time.Second) + plugin.wg.Add(1) + go plugin.cleanupWorker() + + return plugin, nil +} + +// cleanupWorker periodically removes old processing logs +func (p *LoggerPlugin) cleanupWorker() { + defer p.wg.Done() + + for { + select { + case <-p.cleanupTicker.C: + p.cleanupOldProcessingLogs() + + case <-p.done: + return + } + } +} + +// cleanupOldProcessingLogs removes processing logs older than 5 minutes +func (p *LoggerPlugin) cleanupOldProcessingLogs() { + // Calculate timestamp for 5 minutes ago + fiveMinutesAgo := time.Now().Add(-1 * 5 * time.Minute) + // Delete processing logs older than 5 minutes using the store + if err := p.store.CleanupLogs(fiveMinutesAgo); err != nil { + p.logger.Error("failed to cleanup old processing logs: %v", err) + } + + // Clean up old stream accumulators + p.cleanupOldStreamAccumulators() +} + +// SetLogCallback sets a callback function that will be called for each log entry +func (p *LoggerPlugin) SetLogCallback(callback LogCallback) { + p.mu.Lock() + defer p.mu.Unlock() + p.logCallback = callback +} + +// GetName returns the name of the plugin +func (p *LoggerPlugin) GetName() string { + return PluginName +} + +// PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O +func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if ctx == nil { + // Log error but don't fail the request + p.logger.Error("context is nil in PreHook") + return req, nil, nil + } + + // Extract request ID from context + requestID, ok := (*ctx).Value(schemas.BifrostContextKey("request-id")).(string) + if !ok || requestID == "" { + // Log error but don't fail the request + p.logger.Error("request-id not found in context or is empty") + return req, nil, nil + } + + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + p.logger.Error("request type not found in context") + return req, nil, nil + } + + // Prepare initial log data + objectType := p.determineObjectType(requestType) + inputHistory := p.extractInputHistory(req.Input) + + initialData := &InitialLogData{ + Provider: string(req.Provider), + Model: req.Model, + Object: objectType, + InputHistory: inputHistory, + Params: req.Params, + SpeechInput: req.Input.SpeechInput, + TranscriptionInput: req.Input.TranscriptionInput, + } + + if req.Params != nil && req.Params.Tools != nil { + initialData.Tools = req.Params.Tools + } + + // Store created timestamp in context for latency calculation optimization + createdTimestamp := time.Now() + *ctx = context.WithValue(*ctx, CreatedTimestampKey, createdTimestamp) + + // Queue the log creation message (non-blocking) - Using sync.Pool + logMsg := p.getLogMessage() + logMsg.Operation = LogOperationCreate + logMsg.RequestID = requestID + logMsg.Timestamp = createdTimestamp + logMsg.InitialData = initialData + + go func(logMsg *LogMessage) { + defer p.putLogMessage(logMsg) // Return to pool when done + if err := p.insertInitialLogEntry(logMsg.RequestID, logMsg.Timestamp, logMsg.InitialData); err != nil { + p.logger.Error("failed to insert initial log entry for request %s: %v", logMsg.RequestID, err) + } else { + // Call callback for initial log creation (WebSocket "create" message) + // Construct LogEntry directly from data we have to avoid database query + p.mu.Lock() + if p.logCallback != nil { + initialEntry := &logstore.Log{ + ID: logMsg.RequestID, + Timestamp: logMsg.Timestamp, + Object: logMsg.InitialData.Object, + Provider: logMsg.InitialData.Provider, + Model: logMsg.InitialData.Model, + InputHistoryParsed: logMsg.InitialData.InputHistory, + ParamsParsed: logMsg.InitialData.Params, + ToolsParsed: logMsg.InitialData.Tools, + Status: "processing", + Stream: false, // Initially false, will be updated if streaming + CreatedAt: logMsg.Timestamp, + } + p.logCallback(initialEntry) + } + p.mu.Unlock() + } + }(logMsg) + + return req, nil, nil +} + +// PostHook is called after a response is received - FULLY ASYNC, NO DATABASE I/O +func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if ctx == nil { + // Log error but don't fail the request + p.logger.Error("context is nil in PostHook") + return result, err, nil + } + + // Check if the create operation was dropped - if so, skip the update + if dropped, ok := (*ctx).Value(DroppedCreateContextKey).(bool); ok && dropped { + // Create was dropped, skip update to avoid wasted processing and errors + return result, err, nil + } + + // Extract request ID from context + requestID, ok := (*ctx).Value(schemas.BifrostContextKey("request-id")).(string) + if !ok || requestID == "" { + // Log error but don't fail the request + p.logger.Error("request-id not found in context or is empty") + return result, err, nil + } + + provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) + if !ok { + p.logger.Error("provider not found in context") + return result, err, nil + } + + model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) + if !ok { + p.logger.Error("model not found in context") + return result, err, nil + } + // Check if this is a streaming response + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + p.logger.Error("request type missing/invalid in PostHook for request %s", requestID) + return result, err, nil + } + isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest + isChatStreaming := requestType == schemas.ChatCompletionStreamRequest + + // Queue the log update message (non-blocking) - use same pattern for both streaming and regular + logMsg := p.getLogMessage() + logMsg.RequestID = requestID + logMsg.Timestamp = time.Now() + + isFinalChunk := bifrost.IsFinalChunk(ctx) + + if isChatStreaming { + // Handle text-based streaming with ordered accumulation + return p.handleStreamingResponse(ctx, result, err) + } else if isAudioStreaming { + // Handle speech/transcription streaming with original flow + logMsg.Operation = LogOperationStreamUpdate + + // Prepare lightweight streaming update data + streamUpdateData := p.getStreamUpdateData() + + if err != nil { + // Error case + streamUpdateData.ErrorDetails = err + } else if result != nil { + if result.Model != "" { + streamUpdateData.Model = model + } + + // Update object type if available + if result.Object != "" { + streamUpdateData.Object = result.Object + } + + // Token usage + if result.Usage != nil && result.Usage.TotalTokens > 0 { + streamUpdateData.TokenUsage = result.Usage + } + + // Extract token usage from speech and transcription streaming (lightweight) + if result.Speech != nil && result.Speech.Usage != nil && streamUpdateData.TokenUsage == nil { + streamUpdateData.TokenUsage = &schemas.LLMUsage{ + PromptTokens: result.Speech.Usage.InputTokens, + CompletionTokens: result.Speech.Usage.OutputTokens, + TotalTokens: result.Speech.Usage.TotalTokens, + } + } + if result.Transcribe != nil && result.Transcribe.Usage != nil && streamUpdateData.TokenUsage == nil { + transcriptionUsage := result.Transcribe.Usage + streamUpdateData.TokenUsage = &schemas.LLMUsage{} + + if transcriptionUsage.InputTokens != nil { + streamUpdateData.TokenUsage.PromptTokens = *transcriptionUsage.InputTokens + } + if transcriptionUsage.OutputTokens != nil { + streamUpdateData.TokenUsage.CompletionTokens = *transcriptionUsage.OutputTokens + } + if transcriptionUsage.TotalTokens != nil { + streamUpdateData.TokenUsage.TotalTokens = *transcriptionUsage.TotalTokens + } + } + if result.Transcribe != nil && result.Transcribe.BifrostTranscribeStreamResponse != nil && result.Transcribe.Text != "" { + streamUpdateData.TranscriptionOutput = result.Transcribe + } + } + + logMsg.StreamUpdateData = streamUpdateData + } else { + // Handle regular response + logMsg.Operation = LogOperationUpdate + + // Prepare update data (latency will be calculated in background worker) + updateData := p.getUpdateLogData() + + if err != nil { + // Error case + updateData.Status = "error" + updateData.ErrorDetails = err + } else if result != nil { + // Success case + updateData.Status = "success" + + if result.Model != "" { + updateData.Model = model + } + + // Update object type if available + if result.Object != "" { + updateData.Object = result.Object + } + + // Token usage + if result.Usage != nil && result.Usage.TotalTokens > 0 { + updateData.TokenUsage = result.Usage + } + + // Output message and tool calls + if len(result.Choices) > 0 { + choice := result.Choices[0] + + // Check if this is a non-stream response choice + if choice.BifrostNonStreamResponseChoice != nil { + updateData.OutputMessage = &choice.BifrostNonStreamResponseChoice.Message + + // Extract tool calls if present + if choice.BifrostNonStreamResponseChoice.Message.AssistantMessage != nil && + choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls != nil { + updateData.ToolCalls = choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls + } + } + } + + if result.Data != nil { + updateData.EmbeddingOutput = &result.Data + } + + // Handle speech and transcription outputs for NON-streaming responses + if result.Speech != nil { + updateData.SpeechOutput = result.Speech + // Extract token usage + if result.Speech.Usage != nil && updateData.TokenUsage == nil { + updateData.TokenUsage = &schemas.LLMUsage{ + PromptTokens: result.Speech.Usage.InputTokens, + CompletionTokens: result.Speech.Usage.OutputTokens, + TotalTokens: result.Speech.Usage.TotalTokens, + } + } + } + if result.Transcribe != nil { + updateData.TranscriptionOutput = result.Transcribe + // Extract token usage + if result.Transcribe.Usage != nil && updateData.TokenUsage == nil { + transcriptionUsage := result.Transcribe.Usage + updateData.TokenUsage = &schemas.LLMUsage{} + + if transcriptionUsage.InputTokens != nil { + updateData.TokenUsage.PromptTokens = *transcriptionUsage.InputTokens + } + if transcriptionUsage.OutputTokens != nil { + updateData.TokenUsage.CompletionTokens = *transcriptionUsage.OutputTokens + } + if transcriptionUsage.TotalTokens != nil { + updateData.TokenUsage.TotalTokens = *transcriptionUsage.TotalTokens + } + } + } + } + + logMsg.UpdateData = updateData + } + + // Both streaming and regular updates now use the same async pattern + go func() { + defer p.putLogMessage(logMsg) // Return to pool when done + + // Return pooled data structures to their respective pools + defer func() { + if logMsg.UpdateData != nil { + p.putUpdateLogData(logMsg.UpdateData) + } + if logMsg.StreamUpdateData != nil { + p.putStreamUpdateData(logMsg.StreamUpdateData) + } + }() + + if result != nil { + logMsg.SemanticCacheDebug = result.ExtraFields.CacheDebug + } + + if logMsg.UpdateData != nil && p.pricingManager != nil { + cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + logMsg.UpdateData.Cost = &cost + } + if logMsg.StreamUpdateData != nil && isFinalChunk && p.pricingManager != nil { + cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + logMsg.StreamUpdateData.Cost = &cost + } + + var processingErr error + if logMsg.Operation == LogOperationStreamUpdate { + processingErr = retryOnNotFound(*ctx, func() error { + return p.processStreamUpdate(*ctx, logMsg.RequestID, logMsg.Timestamp, logMsg.SemanticCacheDebug, logMsg.StreamUpdateData, isFinalChunk) + }) + } else { + processingErr = retryOnNotFound(*ctx, func() error { + return p.updateLogEntry(*ctx, logMsg.RequestID, logMsg.Timestamp, logMsg.SemanticCacheDebug, logMsg.UpdateData) + }) + } + if processingErr != nil { + p.logger.Error("failed to process log update for request %s: %v", logMsg.RequestID, processingErr) + } else { + // Call callback immediately for both streaming and regular updates + // UI will handle debouncing if needed + p.mu.Lock() + if p.logCallback != nil { + if updatedEntry, getErr := p.getLogEntry(logMsg.RequestID); getErr == nil { + p.logCallback(updatedEntry) + } + } + p.mu.Unlock() + } + }() + + return result, err, nil +} + +// Cleanup is called when the plugin is being shut down +func (p *LoggerPlugin) Cleanup() error { + // Stop the cleanup ticker + if p.cleanupTicker != nil { + p.cleanupTicker.Stop() + } + + // Signal the background worker to stop + close(p.done) + + // Wait for the background worker to finish processing remaining items + p.wg.Wait() + + // Clean up all stream accumulators + p.streamAccumulators.Range(func(key, value interface{}) bool { + acc := value.(*StreamAccumulator) + for _, c := range acc.Chunks { + p.putStreamChunk(c) + } + p.streamAccumulators.Delete(key) + return true + }) + + // GORM handles connection cleanup automatically + return nil +} + +// Helper methods + +// determineObjectType determines the object type from request input +func (p *LoggerPlugin) determineObjectType(requestType schemas.RequestType) string { + switch requestType { + case schemas.TextCompletionRequest: + return "text.completion" + case schemas.ChatCompletionRequest: + return "chat.completion" + case schemas.ChatCompletionStreamRequest: + return "chat.completion.chunk" + case schemas.EmbeddingRequest: + return "list" + case schemas.SpeechRequest: + return "audio.speech" + case schemas.SpeechStreamRequest: + return "audio.speech.chunk" + case schemas.TranscriptionRequest: + return "audio.transcription" + case schemas.TranscriptionStreamRequest: + return "audio.transcription.chunk" + } + return "unknown" +} + +// extractInputHistory extracts input history from request input +func (p *LoggerPlugin) extractInputHistory(input schemas.RequestInput) []schemas.BifrostMessage { + if input.ChatCompletionInput != nil { + return *input.ChatCompletionInput + } + if input.TextCompletionInput != nil { + // Convert text completion to message format + return []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: input.TextCompletionInput, + }, + }, + } + } + if input.EmbeddingInput != nil { + texts := input.EmbeddingInput.Texts + + if len(texts) == 0 && input.EmbeddingInput.Text != nil { + texts = []string{*input.EmbeddingInput.Text} + } + + contentBlocks := make([]schemas.ContentBlock, len(texts)) + for i, text := range texts { + contentBlocks[i] = schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &text, + } + } + return []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + }, + } + } + return []schemas.BifrostMessage{} +} diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go new file mode 100644 index 000000000..4a344c171 --- /dev/null +++ b/plugins/logging/operations.go @@ -0,0 +1,414 @@ +// Package logging provides database operations for the GORM-based logging plugin +package logging + +import ( + "context" + "fmt" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" +) + +// insertInitialLogEntry creates a new log entry in the database using GORM +func (p *LoggerPlugin) insertInitialLogEntry(requestID string, timestamp time.Time, data *InitialLogData) error { + entry := &logstore.Log{ + ID: requestID, + Timestamp: timestamp, + Object: data.Object, + Provider: data.Provider, + Model: data.Model, + Status: "processing", + Stream: false, + CreatedAt: timestamp, + // Set parsed fields for serialization + InputHistoryParsed: data.InputHistory, + ParamsParsed: data.Params, + ToolsParsed: data.Tools, + SpeechInputParsed: data.SpeechInput, + TranscriptionInputParsed: data.TranscriptionInput, + } + + return p.store.Create(entry) +} + +// updateLogEntry updates an existing log entry using GORM +func (p *LoggerPlugin) updateLogEntry(ctx context.Context, requestID string, timestamp time.Time, cacheDebug *schemas.BifrostCacheDebug, data *UpdateLogData) error { + updates := make(map[string]interface{}) + if !timestamp.IsZero() { + // Try to get original timestamp from context first for latency calculation + latency, err := p.calculateLatency(requestID, timestamp, ctx) + if err != nil { + return err + } + updates["latency"] = latency + } + updates["status"] = data.Status + if data.Model != "" { + updates["model"] = data.Model + } + if data.Object != "" { + updates["object_type"] = data.Object // Note: using object_type for database column + } + // Handle JSON fields by setting them on a temporary entry and serializing + tempEntry := &logstore.Log{} + if data.OutputMessage != nil { + tempEntry.OutputMessageParsed = data.OutputMessage + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize output message: %v", err) + } else { + updates["output_message"] = tempEntry.OutputMessage + updates["content_summary"] = tempEntry.ContentSummary // Update content summary + } + } + + if data.EmbeddingOutput != nil { + tempEntry.EmbeddingOutputParsed = data.EmbeddingOutput + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize embedding output: %v", err) + } else { + updates["embedding_output"] = tempEntry.EmbeddingOutput + } + } + + if data.ToolCalls != nil { + tempEntry.ToolCallsParsed = data.ToolCalls + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize tool calls: %v", err) + } else { + updates["tool_calls"] = tempEntry.ToolCalls + } + } + + if data.SpeechOutput != nil { + tempEntry.SpeechOutputParsed = data.SpeechOutput + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize speech output: %v", err) + } else { + updates["speech_output"] = tempEntry.SpeechOutput + } + } + + if data.TranscriptionOutput != nil { + tempEntry.TranscriptionOutputParsed = data.TranscriptionOutput + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize transcription output: %v", err) + } else { + updates["transcription_output"] = tempEntry.TranscriptionOutput + } + } + + if data.TokenUsage != nil { + tempEntry.TokenUsageParsed = data.TokenUsage + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize token usage: %v", err) + } else { + updates["token_usage"] = tempEntry.TokenUsage + updates["prompt_tokens"] = data.TokenUsage.PromptTokens + updates["completion_tokens"] = data.TokenUsage.CompletionTokens + updates["total_tokens"] = data.TokenUsage.TotalTokens + } + } + + // Handle cost from pricing plugin + if data.Cost != nil { + updates["cost"] = *data.Cost + } + + // Handle cache debug + if cacheDebug != nil { + tempEntry.CacheDebugParsed = cacheDebug + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize cache debug: %v", err) + } else { + updates["cache_debug"] = tempEntry.CacheDebug + } + } + + if data.ErrorDetails != nil { + tempEntry.ErrorDetailsParsed = data.ErrorDetails + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize error details: %v", err) + } else { + updates["error_details"] = tempEntry.ErrorDetails + } + } + return p.store.Update(requestID, updates) +} + +// processStreamUpdate handles streaming updates using GORM +func (p *LoggerPlugin) processStreamUpdate(ctx context.Context, requestID string, timestamp time.Time, cacheDebug *schemas.BifrostCacheDebug, data *StreamUpdateData, isFinalChunk bool) error { + updates := make(map[string]interface{}) + + // Handle error case first + if data.ErrorDetails != nil { + latency, err := p.calculateLatency(requestID, timestamp, ctx) + if err != nil { + // If we can't get created_at, just update status and error + tempEntry := &logstore.Log{} + tempEntry.ErrorDetailsParsed = data.ErrorDetails + if err := tempEntry.SerializeFields(); err == nil { + return p.store.Update(requestID, map[string]interface{}{ + "status": "error", + "error_details": tempEntry.ErrorDetails, + "timestamp": timestamp, + }) + } + return err + } + + tempEntry := &logstore.Log{} + tempEntry.ErrorDetailsParsed = data.ErrorDetails + if err := tempEntry.SerializeFields(); err != nil { + return fmt.Errorf("failed to serialize error details: %w", err) + } + return p.store.Update(requestID, map[string]interface{}{ + "status": "error", + "latency": latency, + "timestamp": timestamp, + "error_details": tempEntry.ErrorDetails, + }) + } + + // Always mark as streaming and update timestamp + updates["stream"] = true + updates["timestamp"] = timestamp + + // Calculate latency when stream finishes + var needsLatency bool + var latency float64 + tempEntry := &logstore.Log{} + + if isFinalChunk { + // Stream is finishing, calculate latency + var err error + latency, err = p.calculateLatency(requestID, timestamp, ctx) + if err != nil { + return fmt.Errorf("failed to get created_at for latency calculation: %w", err) + } + needsLatency = true + } + + // Add latency if this is the final chunk + if needsLatency { + updates["latency"] = latency + } + + // Update model if provided + if data.Model != "" { + updates["model"] = data.Model + } + + // Update object type if provided + if data.Object != "" { + updates["object_type"] = data.Object // Note: using object_type for database column + } + + // Update token usage if provided + if data.TokenUsage != nil { + tempEntry.TokenUsageParsed = data.TokenUsage + if err := tempEntry.SerializeFields(); err == nil { + updates["token_usage"] = tempEntry.TokenUsage + updates["prompt_tokens"] = data.TokenUsage.PromptTokens + updates["completion_tokens"] = data.TokenUsage.CompletionTokens + updates["total_tokens"] = data.TokenUsage.TotalTokens + } + } + + // Handle cost from pricing plugin + if data.Cost != nil { + updates["cost"] = *data.Cost + } + + // Handle finish reason - if present, mark as complete + if isFinalChunk { + updates["status"] = "success" + } + + // Process delta content and tool calls if present + if data.Delta != nil { + deltaUpdates, err := p.prepareDeltaUpdates(requestID, data.Delta) + if err != nil { + return fmt.Errorf("failed to prepare delta updates: %w", err) + } + // Merge delta updates into main updates + for key, value := range deltaUpdates { + updates[key] = value + } + } + + // Handle transcription output from stream updates + if data.TranscriptionOutput != nil { + tempEntry.TranscriptionOutputParsed = data.TranscriptionOutput + // Here we just log error but move one vs breaking the entire logging flow + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Warn("failed to serialize transcription output: %v", err) + } else { + updates["transcription_output"] = tempEntry.TranscriptionOutput + } + } + + // Handle cache debug + if cacheDebug != nil { + tempEntry.CacheDebugParsed = cacheDebug + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize cache debug: %v", err) + } else { + updates["cache_debug"] = tempEntry.CacheDebug + } + } + + // Only perform update if there's something to update + if len(updates) > 0 { + return p.store.Update(requestID, updates) + } + + return nil +} + +// calculateLatency computes latency in milliseconds from creation time +func (p *LoggerPlugin) calculateLatency(requestID string, currentTime time.Time, ctx context.Context) (float64, error) { + // Try to get original timestamp from context first + if ctxTimestamp, ok := ctx.Value(CreatedTimestampKey).(time.Time); ok { + return float64(currentTime.Sub(ctxTimestamp).Nanoseconds()) / 1e6, nil + } + + // Fallback to database query if not found in context + originalEntry, err := p.store.FindFirst(map[string]interface{}{"id": requestID}, "created_at") + if err != nil { + return 0, err + } + return float64(currentTime.Sub(originalEntry.CreatedAt).Nanoseconds()) / 1e6, nil +} + +// prepareDeltaUpdates prepares updates for streaming delta content without executing them +func (p *LoggerPlugin) prepareDeltaUpdates(requestID string, delta *schemas.BifrostStreamDelta) (map[string]interface{}, error) { + // Only fetch existing content if we have content or tool calls to append + if (delta.Content == nil || *delta.Content == "") && len(delta.ToolCalls) == 0 && delta.Refusal == nil { + return map[string]interface{}{}, nil + } + + // Get current entry + var currentEntry *logstore.Log + currentEntry, err := p.store.FindFirst(map[string]interface{}{"id": requestID}, "output_message") + if err != nil { + return nil, fmt.Errorf("failed to get existing entry: %w", err) + } + + // Parse existing message or create new one + var outputMessage *schemas.BifrostMessage + if currentEntry.OutputMessage != "" { + outputMessage = &schemas.BifrostMessage{} + // Attempt to deserialize; use parsed message only if successful + if err := currentEntry.DeserializeFields(); err == nil && currentEntry.OutputMessageParsed != nil { + outputMessage = currentEntry.OutputMessageParsed + } else { + // Create new message if parsing fails + outputMessage = &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{}, + } + } + } else { + // Create new message + outputMessage = &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{}, + } + } + + // Handle role (usually in first chunk) + if delta.Role != nil { + outputMessage.Role = schemas.ModelChatMessageRole(*delta.Role) + } + + // Append content + if delta.Content != nil && *delta.Content != "" { + p.appendContentToMessage(outputMessage, *delta.Content) + } + + // Handle refusal + if delta.Refusal != nil && *delta.Refusal != "" { + if outputMessage.AssistantMessage == nil { + outputMessage.AssistantMessage = &schemas.AssistantMessage{} + } + if outputMessage.AssistantMessage.Refusal == nil { + outputMessage.AssistantMessage.Refusal = delta.Refusal + } else { + *outputMessage.AssistantMessage.Refusal += *delta.Refusal + } + } + + // Accumulate tool calls + if len(delta.ToolCalls) > 0 { + p.accumulateToolCallsInMessage(outputMessage, delta.ToolCalls) + } + + // Update the database with new content + tempEntry := &logstore.Log{ + OutputMessageParsed: outputMessage, + } + if outputMessage.AssistantMessage != nil && outputMessage.AssistantMessage.ToolCalls != nil { + tempEntry.ToolCallsParsed = outputMessage.AssistantMessage.ToolCalls + } + + if err := tempEntry.SerializeFields(); err != nil { + return nil, fmt.Errorf("failed to serialize fields: %w", err) + } + + updates := map[string]interface{}{ + "output_message": tempEntry.OutputMessage, + "content_summary": tempEntry.ContentSummary, + } + + // Also update tool_calls field for backward compatibility + if tempEntry.ToolCalls != "" { + updates["tool_calls"] = tempEntry.ToolCalls + } + + return updates, nil +} + +// getLogEntry retrieves a log entry by ID using GORM +func (p *LoggerPlugin) getLogEntry(requestID string) (*logstore.Log, error) { + entry, err := p.store.FindFirst(map[string]interface{}{"id": requestID}) + if err != nil { + return nil, err + } + return entry, nil +} + +// SearchLogs searches logs with filters and pagination using GORM +func (p *LoggerPlugin) SearchLogs(filters logstore.SearchFilters, pagination logstore.PaginationOptions) (*logstore.SearchResult, error) { + // Set default pagination if not provided + if pagination.Limit == 0 { + pagination.Limit = 50 + } + if pagination.SortBy == "" { + pagination.SortBy = "timestamp" + } + if pagination.Order == "" { + pagination.Order = "desc" + } + // Build base query with all filters applied + return p.store.SearchLogs(filters, pagination) +} + +// GetAvailableModels returns all unique models from logs +func (p *LoggerPlugin) GetAvailableModels() []string { + modelSet := make(map[string]bool) + // Query distinct models from logs + result, err := p.store.FindAll("model IS NOT NULL AND model != ''", "model") + if err != nil { + p.logger.Error("failed to get available models: %w", err) + return []string{} + } + for _, model := range result { + modelSet[model.Model] = true + } + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + return models +} diff --git a/plugins/logging/pool.go b/plugins/logging/pool.go new file mode 100644 index 000000000..3c6b05ce3 --- /dev/null +++ b/plugins/logging/pool.go @@ -0,0 +1,81 @@ +package logging + +import "time" + +// getLogMessage gets a LogMessage from the pool +func (p *LoggerPlugin) getLogMessage() *LogMessage { + return p.logMsgPool.Get().(*LogMessage) +} + +// putLogMessage returns a LogMessage to the pool after resetting it +func (p *LoggerPlugin) putLogMessage(msg *LogMessage) { + // Reset the message fields to avoid memory leaks + msg.Operation = "" + msg.RequestID = "" + msg.Timestamp = time.Time{} + msg.InitialData = nil + + // Don't reset UpdateData and StreamUpdateData here since they're returned + // to their own pools in the defer function - just clear the pointers + msg.UpdateData = nil + msg.StreamUpdateData = nil + + p.logMsgPool.Put(msg) +} + +// getUpdateLogData gets an UpdateLogData from the pool +func (p *LoggerPlugin) getUpdateLogData() *UpdateLogData { + return p.updateDataPool.Get().(*UpdateLogData) +} + +// putUpdateLogData returns an UpdateLogData to the pool after resetting it +func (p *LoggerPlugin) putUpdateLogData(data *UpdateLogData) { + // Reset all fields to avoid memory leaks + data.Status = "" + data.TokenUsage = nil + data.OutputMessage = nil + data.ToolCalls = nil + data.ErrorDetails = nil + data.Model = "" + data.Object = "" + data.SpeechOutput = nil + data.TranscriptionOutput = nil + + p.updateDataPool.Put(data) +} + +// getStreamUpdateData gets a StreamUpdateData from the pool +func (p *LoggerPlugin) getStreamUpdateData() *StreamUpdateData { + return p.streamDataPool.Get().(*StreamUpdateData) +} + +// putStreamUpdateData returns a StreamUpdateData to the pool after resetting it +func (p *LoggerPlugin) putStreamUpdateData(data *StreamUpdateData) { + // Reset all fields to avoid memory leaks + data.ErrorDetails = nil + data.Model = "" + data.Object = "" + data.TokenUsage = nil + data.Delta = nil + data.FinishReason = nil + data.TranscriptionOutput = nil + + p.streamDataPool.Put(data) +} + +// getStreamChunk gets a StreamChunk from the pool +func (p *LoggerPlugin) getStreamChunk() *StreamChunk { + return p.streamChunkPool.Get().(*StreamChunk) +} + +// putStreamChunk returns a StreamChunk to the pool after resetting it +func (p *LoggerPlugin) putStreamChunk(chunk *StreamChunk) { + // Reset all fields to avoid memory leaks + chunk.Timestamp = time.Time{} + chunk.Delta = nil + chunk.FinishReason = nil + chunk.TokenUsage = nil + chunk.ErrorDetails = nil + + p.streamChunkPool.Put(chunk) +} diff --git a/plugins/logging/streaming.go b/plugins/logging/streaming.go new file mode 100644 index 000000000..206f8822c --- /dev/null +++ b/plugins/logging/streaming.go @@ -0,0 +1,414 @@ +// Package logging provides streaming-related functionality for the GORM-based logging plugin +package logging + +import ( + "context" + "fmt" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" +) + +// appendContentToMessage efficiently appends content to a message +func (p *LoggerPlugin) appendContentToMessage(message *schemas.BifrostMessage, newContent string) { + if message == nil { + return + } + if message.Content.ContentStr != nil { + // Append to existing string content + *message.Content.ContentStr += newContent + } else if message.Content.ContentBlocks != nil { + // Find the last text block and append, or create new one + blocks := *message.Content.ContentBlocks + if len(blocks) > 0 && blocks[len(blocks)-1].Type == schemas.ContentBlockTypeText && blocks[len(blocks)-1].Text != nil { + // Append to last text block + *blocks[len(blocks)-1].Text += newContent + } else { + // Create new text block + blocks = append(blocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &newContent, + }) + message.Content.ContentBlocks = &blocks + } + } else { + // Initialize with string content + message.Content.ContentStr = &newContent + } +} + +// accumulateToolCallsInMessage efficiently accumulates tool calls in a message +func (p *LoggerPlugin) accumulateToolCallsInMessage(message *schemas.BifrostMessage, deltaToolCalls []schemas.ToolCall) { + if message == nil { + return + } + if message.AssistantMessage == nil { + message.AssistantMessage = &schemas.AssistantMessage{} + } + + if message.AssistantMessage.ToolCalls == nil { + message.AssistantMessage.ToolCalls = &[]schemas.ToolCall{} + } + + existingToolCalls := *message.AssistantMessage.ToolCalls + + for _, deltaToolCall := range deltaToolCalls { + // Find existing tool call with same ID or create new one + found := false + for i := range existingToolCalls { + if existingToolCalls[i].ID != nil && deltaToolCall.ID != nil && + *existingToolCalls[i].ID == *deltaToolCall.ID { + // Append arguments to existing tool call + existingToolCalls[i].Function.Arguments += deltaToolCall.Function.Arguments + found = true + break + } + } + if !found { + // Add new tool call + existingToolCalls = append(existingToolCalls, deltaToolCall) + } + } + message.AssistantMessage.ToolCalls = &existingToolCalls +} + +// Stream accumulator helper methods + +// createStreamAccumulator creates a new stream accumulator for a request +func (p *LoggerPlugin) createStreamAccumulator(requestID string) *StreamAccumulator { + accumulator := &StreamAccumulator{ + RequestID: requestID, + Chunks: make([]*StreamChunk, 0), + IsComplete: false, + Object: "", + } + + p.streamAccumulators.Store(requestID, accumulator) + return accumulator +} + +// getOrCreateStreamAccumulator gets or creates a stream accumulator for a request +func (p *LoggerPlugin) getOrCreateStreamAccumulator(requestID string) *StreamAccumulator { + if accumulator, exists := p.streamAccumulators.Load(requestID); exists { + return accumulator.(*StreamAccumulator) + } + + // Create new accumulator if it doesn't exist + return p.createStreamAccumulator(requestID) +} + +// addStreamChunk adds a chunk to the stream accumulator +func (p *LoggerPlugin) addStreamChunk(requestID string, chunk *StreamChunk, object string, isFinalChunk bool) error { + accumulator := p.getOrCreateStreamAccumulator(requestID) + + accumulator.mu.Lock() + defer accumulator.mu.Unlock() + + // Store object type once (from first chunk) + if accumulator.Object == "" && object != "" { + accumulator.Object = object + } + + // Add chunk to the list (chunks arrive in order) + accumulator.Chunks = append(accumulator.Chunks, chunk) + + // Check if this is the final chunk + // Set FinalTimestamp when either FinishReason is present or token usage exists + // This handles both normal completion chunks and usage-only last chunks + if isFinalChunk { + accumulator.FinalTimestamp = chunk.Timestamp + } + + return nil +} + +// processAccumulatedChunks processes all accumulated chunks in order +func (p *LoggerPlugin) processAccumulatedChunks(requestID string) error { + accumulator := p.getOrCreateStreamAccumulator(requestID) + + accumulator.mu.Lock() + defer accumulator.mu.Unlock() + + // Ensure cleanup happens + defer p.cleanupStreamAccumulator(requestID) + + // Build complete message from accumulated chunks + completeMessage := p.buildCompleteMessageFromChunks(accumulator.Chunks) + + // Calculate final latency + latency, err := p.calculateLatency(requestID, accumulator.FinalTimestamp, context.Background()) + if err != nil { + p.logger.Error("failed to calculate latency for request %s: %v", requestID, err) + latency = 0 + } + + // Update database with complete message + updates := make(map[string]interface{}) + updates["status"] = "success" + updates["stream"] = true + updates["latency"] = latency + updates["timestamp"] = accumulator.FinalTimestamp + + // Serialize complete message + tempEntry := &logstore.Log{ + OutputMessageParsed: completeMessage, + } + if completeMessage.AssistantMessage != nil && completeMessage.AssistantMessage.ToolCalls != nil { + tempEntry.ToolCallsParsed = completeMessage.AssistantMessage.ToolCalls + } + + if err := tempEntry.SerializeFields(); err != nil { + return fmt.Errorf("failed to serialize complete message: %w", err) + } + + updates["output_message"] = tempEntry.OutputMessage + updates["content_summary"] = tempEntry.ContentSummary + if tempEntry.ToolCalls != "" { + updates["tool_calls"] = tempEntry.ToolCalls + } + + // Update token usage from final chunk if available + if len(accumulator.Chunks) > 0 { + lastChunk := accumulator.Chunks[len(accumulator.Chunks)-1] + if lastChunk.TokenUsage != nil { + tempEntry.TokenUsageParsed = lastChunk.TokenUsage + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize token usage: %v", err) + } else { + updates["token_usage"] = tempEntry.TokenUsage + updates["prompt_tokens"] = lastChunk.TokenUsage.PromptTokens + updates["completion_tokens"] = lastChunk.TokenUsage.CompletionTokens + updates["total_tokens"] = lastChunk.TokenUsage.TotalTokens + } + } + + // Handle cache debug + if lastChunk.SemanticCacheDebug != nil { + tempEntry.CacheDebugParsed = lastChunk.SemanticCacheDebug + if err := tempEntry.SerializeFields(); err != nil { + p.logger.Error("failed to serialize cache debug: %v", err) + } else { + updates["cache_debug"] = tempEntry.CacheDebug + } + } + } + + // Update cost from final chunk if available + if len(accumulator.Chunks) > 0 { + lastChunk := accumulator.Chunks[len(accumulator.Chunks)-1] + if lastChunk.Cost != nil { + updates["cost"] = *lastChunk.Cost + } + } + + // Update object field from accumulator (stored once for the entire stream) + if accumulator.Object != "" { + updates["object_type"] = accumulator.Object + } + + // Perform final database update + if err := p.store.Update(requestID, updates); err != nil { + return fmt.Errorf("failed to update log entry with complete stream: %w", err) + } + + // Trigger callback + p.mu.Lock() + if p.logCallback != nil { + if updatedEntry, getErr := p.getLogEntry(requestID); getErr == nil { + p.logCallback(updatedEntry) + } + } + p.mu.Unlock() + + return nil +} + +// buildCompleteMessageFromChunks builds a complete message from ordered chunks +func (p *LoggerPlugin) buildCompleteMessageFromChunks(chunks []*StreamChunk) *schemas.BifrostMessage { + completeMessage := &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{}, + } + + for _, chunk := range chunks { + if chunk.Delta == nil { + continue + } + + // Handle role (usually in first chunk) + if chunk.Delta.Role != nil { + completeMessage.Role = schemas.ModelChatMessageRole(*chunk.Delta.Role) + } + + // Append content + if chunk.Delta.Content != nil && *chunk.Delta.Content != "" { + p.appendContentToMessage(completeMessage, *chunk.Delta.Content) + } + + // Handle refusal + if chunk.Delta.Refusal != nil && *chunk.Delta.Refusal != "" { + if completeMessage.AssistantMessage == nil { + completeMessage.AssistantMessage = &schemas.AssistantMessage{} + } + if completeMessage.AssistantMessage.Refusal == nil { + completeMessage.AssistantMessage.Refusal = chunk.Delta.Refusal + } else { + *completeMessage.AssistantMessage.Refusal += *chunk.Delta.Refusal + } + } + + // Accumulate tool calls + if len(chunk.Delta.ToolCalls) > 0 { + p.accumulateToolCallsInMessage(completeMessage, chunk.Delta.ToolCalls) + } + } + + return completeMessage +} + +// cleanupStreamAccumulator removes the stream accumulator for a request +func (p *LoggerPlugin) cleanupStreamAccumulator(requestID string) { + if accumulator, exists := p.streamAccumulators.Load(requestID); exists { + // Return all chunks to the pool before deleting + acc := accumulator.(*StreamAccumulator) + for _, chunk := range acc.Chunks { + p.putStreamChunk(chunk) + } + p.streamAccumulators.Delete(requestID) + } +} + +// cleanupOldStreamAccumulators removes stream accumulators older than 5 minutes +func (p *LoggerPlugin) cleanupOldStreamAccumulators() { + fiveMinutesAgo := time.Now().Add(-5 * time.Minute) + cleanedCount := 0 + + p.streamAccumulators.Range(func(key, value interface{}) bool { + requestID := key.(string) + accumulator := value.(*StreamAccumulator) + accumulator.mu.Lock() + defer accumulator.mu.Unlock() + + // Check if this accumulator is old (no activity for 5 minutes) + // Use the timestamp of the first chunk as a reference + if len(accumulator.Chunks) > 0 { + firstChunkTime := accumulator.Chunks[0].Timestamp + if firstChunkTime.Before(fiveMinutesAgo) { + // Return all chunks to the pool + for _, chunk := range accumulator.Chunks { + p.putStreamChunk(chunk) + } + p.streamAccumulators.Delete(requestID) + cleanedCount++ + p.logger.Debug("cleaned up old stream accumulator for request %s") + } + } + return true + }) + + if cleanedCount > 0 { + p.logger.Debug("cleaned up %d old stream accumulators", cleanedCount) + } +} + +// handleStreamingResponse handles streaming responses with ordered accumulation +func (p *LoggerPlugin) handleStreamingResponse(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + requestID, ok := (*ctx).Value(schemas.BifrostContextKey("request-id")).(string) + if !ok || requestID == "" { + p.logger.Error("request-id not found in context or is empty") + return result, err, nil + } + + provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) + if !ok { + p.logger.Error("provider not found in context") + return result, err, nil + } + + model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) + if !ok { + p.logger.Error("model not found in context") + return result, err, nil + } + + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + p.logger.Error("request type not found in context") + return result, err, nil + } + + // Create chunk from current response using pool + chunk := p.getStreamChunk() + chunk.Timestamp = time.Now() + chunk.ErrorDetails = err + + if err != nil { + // Error case - mark as final chunk + chunk.FinishReason = bifrost.Ptr("error") + } else if result != nil { + // Extract delta and other information + if len(result.Choices) > 0 { + choice := result.Choices[0] + if choice.BifrostStreamResponseChoice != nil { + // Create a deep copy of the Delta to avoid pointing to stack memory + deltaCopy := choice.BifrostStreamResponseChoice.Delta + chunk.Delta = &deltaCopy + chunk.FinishReason = choice.FinishReason + } + } + + // Extract token usage + if result.Usage != nil && result.Usage.TotalTokens > 0 { + chunk.TokenUsage = result.Usage + } + } + + isFinalChunk := bifrost.IsFinalChunk(ctx) + + go func() { + // Add chunk to accumulator synchronously to maintain order + object := "" + if result != nil { + if isFinalChunk { + if p.pricingManager != nil { + cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + chunk.Cost = bifrost.Ptr(cost) + } + chunk.SemanticCacheDebug = result.ExtraFields.CacheDebug + } + + object = result.Object + } + if addErr := p.addStreamChunk(requestID, chunk, object, isFinalChunk); addErr != nil { + p.logger.Error("failed to add stream chunk for request %s: %v", requestID, addErr) + } + + // If this is the final chunk, process accumulated chunks asynchronously + // Use the IsComplete flag to prevent duplicate processing + shouldProcess := false + if isFinalChunk { + // Get the accumulator to check if processing has already been triggered + accumulator := p.getOrCreateStreamAccumulator(requestID) + accumulator.mu.Lock() + shouldProcess = !accumulator.IsComplete + + // Mark as complete when we're about to process + if shouldProcess { + accumulator.IsComplete = true + } + accumulator.mu.Unlock() + + if shouldProcess { + + if processErr := p.processAccumulatedChunks(requestID); processErr != nil { + p.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) + } + + } + } + }() + + return result, err, nil +} diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go new file mode 100644 index 000000000..a9120ac5c --- /dev/null +++ b/plugins/logging/utils.go @@ -0,0 +1,48 @@ +// Package logging provides utility functions and interfaces for the GORM-based logging plugin +package logging + +import ( + "fmt" + + "github.com/maximhq/bifrost/framework/logstore" +) + +// LogManager defines the main interface that combines all logging functionality +type LogManager interface { + // Search searches for log entries based on filters and pagination + Search(filters *logstore.SearchFilters, pagination *logstore.PaginationOptions) (*logstore.SearchResult, error) + + // Get the number of dropped requests + GetDroppedRequests() int64 + + // GetAvailableModels returns all unique models from logs + GetAvailableModels() []string +} + +// PluginLogManager implements LogManager interface wrapping the plugin +type PluginLogManager struct { + plugin *LoggerPlugin +} + +func (p *PluginLogManager) Search(filters *logstore.SearchFilters, pagination *logstore.PaginationOptions) (*logstore.SearchResult, error) { + if filters == nil || pagination == nil { + return nil, fmt.Errorf("filters and pagination cannot be nil") + } + return p.plugin.SearchLogs(*filters, *pagination) +} + +func (p *PluginLogManager) GetDroppedRequests() int64 { + return p.plugin.droppedRequests.Load() +} + +// GetAvailableModels returns all unique models from logs +func (p *PluginLogManager) GetAvailableModels() []string { + return p.plugin.GetAvailableModels() +} + +// GetPluginLogManager returns a LogManager interface for this plugin +func (p *LoggerPlugin) GetPluginLogManager() *PluginLogManager { + return &PluginLogManager{ + plugin: p, + } +} diff --git a/plugins/logging/version b/plugins/logging/version new file mode 100644 index 000000000..f69752ab1 --- /dev/null +++ b/plugins/logging/version @@ -0,0 +1 @@ +1.2.16 diff --git a/plugins/maxim-sdk.go b/plugins/maxim-sdk.go deleted file mode 100644 index c70ad59e7..000000000 --- a/plugins/maxim-sdk.go +++ /dev/null @@ -1,128 +0,0 @@ -// Package plugins provides plugins for the Bifrost system. -// This file contains the Plugin implementation using maxim's logger plugin for bifrost. -package plugins - -import ( - "context" - "fmt" - "time" - - "github.com/maximhq/bifrost/core/schemas" - - "github.com/maximhq/maxim-go" - "github.com/maximhq/maxim-go/logging" -) - -// NewMaximLogger initializes and returns a Plugin instance for Maxim's logger. -// -// Parameters: -// - apiKey: API key for Maxim SDK authentication -// - loggerId: ID for the Maxim logger instance -// -// Returns: -// - schemas.Plugin: A configured plugin instance for request/response tracing -// - error: Any error that occurred during plugin initialization -func NewMaximLoggerPlugin(apiKey string, loggerId string) (schemas.Plugin, error) { - // check if Maxim Logger variables are set - if apiKey == "" { - return nil, fmt.Errorf("apiKey is not set") - } - - if loggerId == "" { - return nil, fmt.Errorf("loggerId is not set") - } - - mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: apiKey}) - - logger, err := mx.GetLogger(&logging.LoggerConfig{Id: loggerId}) - if err != nil { - return nil, err - } - - plugin := &Plugin{logger} - - return plugin, nil -} - -// contextKey is a custom type for context keys to prevent key collisions in the context. -// It provides type safety for context values and ensures that context keys are unique -// across different packages. -type contextKey string - -// traceIDKey is the context key used to store and retrieve trace IDs. -// This constant provides a consistent key for tracking request traces -// throughout the request/response lifecycle. -const ( - traceIDKey contextKey = "traceID" -) - -// Plugin implements the schemas.Plugin interface for Maxim's logger. -// It provides request and response tracing functionality using the Maxim logger, -// allowing detailed tracking of requests and responses. -// -// Fields: -// - logger: A Maxim logger instance used for tracing requests and responses -type Plugin struct { - logger *logging.Logger -} - -// PreHook is called before a request is processed by Bifrost. -// It creates a new trace for the incoming request and stores the trace ID in the context. -// The trace includes request details that can be used for debugging and monitoring. -// -// Parameters: -// - ctx: Pointer to the context.Context that will store the trace ID -// - req: The incoming Bifrost request to be traced -// -// Returns: -// - *schemas.BifrostRequest: The original request, unmodified -// - error: Always returns nil as this implementation doesn't produce errors -// -// The trace ID format is "YYYYMMDD_HHmmssSSS" based on the current time. -// If the context is nil, tracing information will still be logged but not stored in context. -func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, error) { - traceID := time.Now().Format("20060102_150405000") - - trace := plugin.logger.Trace(&logging.TraceConfig{ - Id: traceID, - Name: maxim.StrPtr("bifrost"), - }) - - trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) - - if ctx != nil { - // Store traceID in context - *ctx = context.WithValue(*ctx, traceIDKey, traceID) - } - - return req, nil -} - -// PostHook is called after a request has been processed by Bifrost. -// It retrieves the trace ID from the context and logs the response details. -// This completes the request trace by adding response information. -// -// Parameters: -// - ctxRef: Pointer to the context.Context containing the trace ID -// - res: The Bifrost response to be traced -// -// Returns: -// - *schemas.BifrostResponse: The original response, unmodified -// - error: Returns an error if the trace ID cannot be retrieved from the context -// -// If the context is nil or the trace ID is not found, an error will be returned -// but the response will still be passed through unmodified. -func (plugin *Plugin) PostHook(ctxRef *context.Context, res *schemas.BifrostResponse) (*schemas.BifrostResponse, error) { - // Get traceID from context - if ctxRef != nil { - ctx := *ctxRef - traceID, ok := ctx.Value(traceIDKey).(string) - if !ok { - return res, fmt.Errorf("traceID not found in context") - } - - plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res)) - } - - return res, nil -} diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/maxim/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/maxim/go.mod b/plugins/maxim/go.mod new file mode 100644 index 000000000..80a4308f8 --- /dev/null +++ b/plugins/maxim/go.mod @@ -0,0 +1,55 @@ +module github.com/maximhq/bifrost/plugins/maxim + +go 1.24.1 + +toolchain go1.24.3 + +require ( + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/maxim-go v0.1.10 +) + +require github.com/google/uuid v1.6.0 + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/maxim/go.sum b/plugins/maxim/go.sum new file mode 100644 index 000000000..fbc80045d --- /dev/null +++ b/plugins/maxim/go.sum @@ -0,0 +1,127 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/maxim-go v0.1.10 h1:rGBYSY3qld2zfZeL4HBmropkyfrqNiJ4IYA49jbvYX8= +github.com/maximhq/maxim-go v0.1.10/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go new file mode 100644 index 000000000..8988ba7dc --- /dev/null +++ b/plugins/maxim/main.go @@ -0,0 +1,433 @@ +// Package maxim provides integration for Maxim's SDK as a Bifrost plugin. +// This file contains the main plugin implementation. +package maxim + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" +) + +// PluginName is the canonical name for the maxim plugin. +const PluginName = "maxim" + +// Config is the configuration for the maxim plugin. +// - apiKey: API key for Maxim SDK authentication +// - logRepoId: Optional default ID for the Maxim logger instance +type Config struct { + LogRepoId string `json:"log_repo_id,omitempty"` // Optional - can be empty + ApiKey string `json:"api_key"` +} + +// Init initializes and returns a Plugin instance for Maxim's logger. +// +// Parameters: +// - config: Configuration for the maxim plugin +// +// Returns: +// - schemas.Plugin: A configured plugin instance for request/response tracing +// - error: Any error that occurred during plugin initialization +func Init(config Config) (schemas.Plugin, error) { + // check if Maxim Logger variables are set + if config.ApiKey == "" { + return nil, fmt.Errorf("apiKey is not set") + } + + mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: config.ApiKey}) + + plugin := &Plugin{ + mx: mx, + defaultLogRepoId: config.LogRepoId, + loggers: make(map[string]*logging.Logger), + loggerMutex: &sync.RWMutex{}, + } + + // Initialize default logger if LogRepoId is provided + if config.LogRepoId != "" { + logger, err := mx.GetLogger(&logging.LoggerConfig{Id: config.LogRepoId}) + if err != nil { + return nil, fmt.Errorf("failed to initialize default logger: %w", err) + } + plugin.loggers[config.LogRepoId] = logger + } + + return plugin, nil +} + +// ContextKey is a custom type for context keys to prevent key collisions in the context. +// It provides type safety for context values and ensures that context keys are unique +// across different packages. +type ContextKey string + +// TraceIDKey is the context key used to store and retrieve trace IDs. +// This constant provides a consistent key for tracking request traces +// throughout the request/response lifecycle. +const ( + SessionIDKey ContextKey = "session-id" + TraceIDKey ContextKey = "trace-id" + TraceNameKey ContextKey = "trace-name" + GenerationIDKey ContextKey = "generation-id" + GenerationNameKey ContextKey = "generation-name" + TagsKey ContextKey = "maxim-tags" + LogRepoIDKey ContextKey = "log-repo-id" +) + +// The plugin provides request/response tracing functionality by integrating with Maxim's logging system. +// It supports both chat completion and text completion requests, tracking the entire lifecycle of each request +// including inputs, parameters, and responses. +// +// Key Features: +// - Automatic trace and generation ID management +// - Support for both chat and text completion requests +// - Contextual tracking across request lifecycle +// - Graceful handling of existing trace/generation IDs +// +// The plugin uses context values to maintain trace and generation IDs throughout the request lifecycle. +// These IDs can be propagated from external systems through HTTP headers (x-bf-maxim-trace-id and x-bf-maxim-generation-id). + +// Plugin implements the schemas.Plugin interface for Maxim's logger. +// It provides request and response tracing functionality using Maxim logger, +// allowing detailed tracking of requests and responses across different log repositories. +// +// Fields: +// - mx: The Maxim SDK instance for creating new loggers +// - defaultLogRepoId: Default log repository ID from config (optional) +// - loggers: Map of log repo ID to logger instances +// - loggerMutex: RW mutex for thread-safe access to loggers map +type Plugin struct { + mx *maxim.Maxim + defaultLogRepoId string + loggers map[string]*logging.Logger + loggerMutex *sync.RWMutex +} + +// GetName returns the name of the plugin. +func (plugin *Plugin) GetName() string { + return PluginName +} + +// getEffectiveLogRepoID determines which single log repo ID to use based on priority: +// 1. Header log repo ID (if provided) +// 2. Default log repo ID from config (if configured) +// 3. Empty string (skip logging) +func (plugin *Plugin) getEffectiveLogRepoID(ctx *context.Context) string { + // Check for header log repo ID first (highest priority) + if ctx != nil { + if headerRepoID, ok := (*ctx).Value(LogRepoIDKey).(string); ok && headerRepoID != "" { + return headerRepoID + } + } + + // Fall back to default log repo ID from config + if plugin.defaultLogRepoId != "" { + return plugin.defaultLogRepoId + } + + // Return empty string if neither header nor default is available + return "" +} + +// getOrCreateLogger gets an existing logger or creates a new one for the given log repo ID +func (plugin *Plugin) getOrCreateLogger(logRepoID string) (*logging.Logger, error) { + // First, try to get existing logger (read lock) + plugin.loggerMutex.RLock() + if logger, exists := plugin.loggers[logRepoID]; exists { + plugin.loggerMutex.RUnlock() + return logger, nil + } + plugin.loggerMutex.RUnlock() + + // Logger doesn't exist, create it (write lock) + plugin.loggerMutex.Lock() + defer plugin.loggerMutex.Unlock() + + // Double-check in case another goroutine created it while we were waiting + if logger, exists := plugin.loggers[logRepoID]; exists { + return logger, nil + } + + // Create new logger + logger, err := plugin.mx.GetLogger(&logging.LoggerConfig{Id: logRepoID}) + if err != nil { + return nil, fmt.Errorf("failed to create logger for repo ID %s: %w", logRepoID, err) + } + + plugin.loggers[logRepoID] = logger + return logger, nil +} + +// PreHook is called before a request is processed by Bifrost. +// It manages trace and generation tracking for incoming requests by either: +// - Creating a new trace if none exists +// - Reusing an existing trace ID from the context +// - Creating a new generation within an existing trace +// - Skipping trace/generation creation if they already exist +// +// The function handles both chat completion and text completion requests, +// capturing relevant metadata such as: +// - Request type (chat/text completion) +// - Model information +// - Message content and role +// - Model parameters +// +// Parameters: +// - ctx: Pointer to the context.Context that may contain existing trace/generation IDs +// - req: The incoming Bifrost request to be traced +// +// Returns: +// - *schemas.BifrostRequest: The original request, unmodified +// - error: Any error that occurred during trace/generation creation +func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + var traceID string + var traceName string + var sessionID string + var generationName string + var tags map[string]string + + // Get effective log repo ID (header > default > skip) + effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctx) + + // If no log repo ID available, skip logging + if effectiveLogRepoID == "" { + return req, nil, nil + } + + // Check if context already has traceID and generationID + if ctx != nil { + if existingGenerationID, ok := (*ctx).Value(GenerationIDKey).(string); ok && existingGenerationID != "" { + // If generationID exists, return early + return req, nil, nil + } + + if existingTraceID, ok := (*ctx).Value(TraceIDKey).(string); ok && existingTraceID != "" { + // If traceID exists, and no generationID, create a new generation on the trace + traceID = existingTraceID + } + + if existingSessionID, ok := (*ctx).Value(SessionIDKey).(string); ok && existingSessionID != "" { + sessionID = existingSessionID + } + + if existingTraceName, ok := (*ctx).Value(TraceNameKey).(string); ok && existingTraceName != "" { + traceName = existingTraceName + } + + if existingGenerationName, ok := (*ctx).Value(GenerationNameKey).(string); ok && existingGenerationName != "" { + generationName = existingGenerationName + } + + // retrieve all tags from context + // the transport layer now stores all maxim tags in a single map + if tagsValue := (*ctx).Value(TagsKey); tagsValue != nil { + if tagsMap, ok := tagsValue.(map[string]string); ok { + tags = make(map[string]string) + for key, value := range tagsMap { + tags[key] = value + } + } + } + } + + // Determine request type and set appropriate tags + var requestType string + var messages []logging.CompletionRequest + var latestMessage string + + // Initialize tags map if not already initialized from context + if tags == nil { + tags = make(map[string]string) + } + + // Add model to tags + tags["model"] = req.Model + + if req.Input.ChatCompletionInput != nil { + requestType = "chat_completion" + for _, message := range *req.Input.ChatCompletionInput { + messages = append(messages, logging.CompletionRequest{ + Role: string(message.Role), + Content: message.Content, + }) + } + if len(*req.Input.ChatCompletionInput) > 0 { + lastMsg := (*req.Input.ChatCompletionInput)[len(*req.Input.ChatCompletionInput)-1] + if lastMsg.Content.ContentStr != nil { + latestMessage = *lastMsg.Content.ContentStr + } else if lastMsg.Content.ContentBlocks != nil { + // Find the last text content block + for i := len(*lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- { + block := (*lastMsg.Content.ContentBlocks)[i] + if block.Type == "text" && block.Text != nil { + latestMessage = *block.Text + break + } + } + // If no text block found, use placeholder + if latestMessage == "" { + latestMessage = "-" + } + } + } + } else if req.Input.TextCompletionInput != nil { + requestType = "text_completion" + messages = append(messages, logging.CompletionRequest{ + Role: string(schemas.ModelChatMessageRoleUser), + Content: req.Input.TextCompletionInput, + }) + latestMessage = *req.Input.TextCompletionInput + } + + tags["action"] = requestType + + if traceID == "" { + // If traceID is not set, create a new trace + traceID = uuid.New().String() + name := fmt.Sprintf("bifrost_%s", requestType) + if traceName != "" { + name = traceName + } + + traceConfig := logging.TraceConfig{ + Id: traceID, + Name: maxim.StrPtr(name), + Tags: &tags, + } + + if sessionID != "" { + traceConfig.SessionId = &sessionID + } + + // Create trace in the effective log repository + logger, err := plugin.getOrCreateLogger(effectiveLogRepoID) + if err == nil { + trace := logger.Trace(&traceConfig) + trace.SetInput(latestMessage) + } + } + + // Convert ModelParameters to map[string]interface{} + modelParams := make(map[string]interface{}) + if req.Params != nil { + // Convert the struct to a map using reflection or JSON marshaling + jsonData, err := json.Marshal(req.Params) + if err == nil { + json.Unmarshal(jsonData, &modelParams) + } + } + + generationID := uuid.New().String() + + generationConfig := logging.GenerationConfig{ + Id: generationID, + Model: req.Model, + Provider: string(req.Provider), + Tags: &tags, + Messages: messages, + ModelParameters: modelParams, + } + + if generationName != "" { + generationConfig.Name = &generationName + } + + // Add generation to the effective log repository + logger, err := plugin.getOrCreateLogger(effectiveLogRepoID) + if err == nil { + logger.AddGenerationToTrace(traceID, &generationConfig) + } + + if ctx != nil { + if _, ok := (*ctx).Value(TraceIDKey).(string); !ok { + *ctx = context.WithValue(*ctx, TraceIDKey, traceID) + } + *ctx = context.WithValue(*ctx, GenerationIDKey, generationID) + } + + return req, nil, nil +} + +// PostHook is called after a request has been processed by Bifrost. +// It completes the request trace by: +// - Adding response data to the generation if a generation ID exists +// - Logging error details if bifrostErr is provided +// - Ending the generation if it exists +// - Ending the trace if a trace ID exists +// - Flushing all pending log data +// +// The function gracefully handles cases where trace or generation IDs may be missing, +// ensuring that partial logging is still performed when possible. +// +// Parameters: +// - ctxRef: Pointer to the context.Context containing trace/generation IDs +// - res: The Bifrost response to be traced +// - bifrostErr: The BifrostError returned by the request, if any +// +// Returns: +// - *schemas.BifrostResponse: The original response, unmodified +// - *schemas.BifrostError: The original error, unmodified +// - error: Never returns an error as it handles missing IDs gracefully +func (plugin *Plugin) PostHook(ctxRef *context.Context, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if ctxRef != nil { + ctx := *ctxRef + + // Get effective log repo ID for this request + effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctxRef) + + generationID, ok := ctx.Value(GenerationIDKey).(string) + if ok && effectiveLogRepoID != "" { + // Process generation completion in the effective log repository + logger, err := plugin.getOrCreateLogger(effectiveLogRepoID) + if err == nil { + if bifrostErr != nil { + genErr := logging.GenerationError{ + Message: bifrostErr.Error.Message, + Code: bifrostErr.Error.Code, + Type: bifrostErr.Error.Type, + } + logger.SetGenerationError(generationID, &genErr) + } else if res != nil { + logger.AddResultToGeneration(generationID, res) + } + + logger.EndGeneration(generationID) + } + } + + traceID, ok := ctx.Value(TraceIDKey).(string) + if ok && effectiveLogRepoID != "" { + // End trace in the effective log repository + logger, err := plugin.getOrCreateLogger(effectiveLogRepoID) + if err == nil { + logger.EndTrace(traceID) + } + } + + // Flush only the effective logger that was used for this request + if effectiveLogRepoID != "" { + logger, err := plugin.getOrCreateLogger(effectiveLogRepoID) + if err == nil { + logger.Flush() + } + } + } + + return res, bifrostErr, nil +} + +func (plugin *Plugin) Cleanup() error { + // Flush all loggers + plugin.loggerMutex.RLock() + for _, logger := range plugin.loggers { + logger.Flush() + } + plugin.loggerMutex.RUnlock() + + return nil +} diff --git a/plugins/maxim/plugin_test.go b/plugins/maxim/plugin_test.go new file mode 100644 index 000000000..9d69fe404 --- /dev/null +++ b/plugins/maxim/plugin_test.go @@ -0,0 +1,258 @@ +// Package maxim provides integration for Maxim's SDK as a Bifrost plugin. +// It includes tests for plugin initialization, Bifrost integration, and request/response tracing. +package maxim + +import ( + "context" + "fmt" + "log" + "os" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// getPlugin initializes and returns a Plugin instance for testing purposes. +// It sets up the Maxim logger with configuration from environment variables. +// +// Environment Variables: +// - MAXIM_API_KEY: API key for Maxim SDK authentication +// - MAXIM_LOG_REPO_ID: ID for the Maxim logger instance +// +// Returns: +// - schemas.Plugin: A configured plugin instance for request/response tracing +// - error: Any error that occurred during plugin initialization +func getPlugin() (schemas.Plugin, error) { + // check if Maxim Logger variables are set + if os.Getenv("MAXIM_API_KEY") == "" { + return nil, fmt.Errorf("MAXIM_API_KEY is not set, please set it in your environment variables") + } + + plugin, err := Init(Config{ + ApiKey: os.Getenv("MAXIM_API_KEY"), + LogRepoId: os.Getenv("MAXIM_LOG_REPO_ID"), + }) + if err != nil { + return nil, err + } + + return plugin, nil +} + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the Maxim plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +// Currently only supports OpenAI for simplicity in testing. You are free to add more providers as needed. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +// GetKeysForProvider returns a mock API key configuration for testing. +// Uses the OPENAI_API_KEY environment variable for authentication. +func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +// Uses standard network and concurrency settings. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestMaximLoggerPlugin tests the integration of the Maxim Logger plugin with Bifrost. +// It performs the following steps: +// 1. Initializes the Maxim plugin with environment variables +// 2. Sets up a test Bifrost instance with the plugin +// 3. Makes a test chat completion request +// +// Required environment variables: +// - MAXIM_API_KEY: Your Maxim API key +// - MAXIM_LOGGER_ID: Your Maxim logger repository ID +// - OPENAI_API_KEY: Your OpenAI API key for the test request +func TestMaximLoggerPlugin(t *testing.T) { + ctx := context.Background() + // Initialize the Maxim plugin + plugin, err := getPlugin() + if err != nil { + t.Fatalf("Error setting up the plugin: %v", err) + } + + account := BaseAccount{} + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + + // Make a test chat completion request + _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, how are you?"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + log.Printf("Error in Bifrost request: %v", bifrostErr) + } + + log.Println("Bifrost request completed, check your Maxim Dashboard for the trace") + + client.Shutdown() +} + +// TestLogRepoIDSelection tests the single repository selection logic +func TestLogRepoIDSelection(t *testing.T) { + tests := []struct { + name string + defaultRepo string + headerRepo string + expectedRepo string + shouldLog bool + }{ + { + name: "Header repo takes priority", + defaultRepo: "default-repo", + headerRepo: "header-repo", + expectedRepo: "header-repo", + shouldLog: true, + }, + { + name: "Fall back to default repo when no header", + defaultRepo: "default-repo", + headerRepo: "", + expectedRepo: "default-repo", + shouldLog: true, + }, + { + name: "Use header repo when no default", + defaultRepo: "", + headerRepo: "header-repo", + expectedRepo: "header-repo", + shouldLog: true, + }, + { + name: "Skip logging when neither available", + defaultRepo: "", + headerRepo: "", + expectedRepo: "", + shouldLog: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create plugin with default repo + plugin := &Plugin{ + defaultLogRepoId: tt.defaultRepo, + } + + // Create context with header repo if provided + ctx := context.Background() + if tt.headerRepo != "" { + ctx = context.WithValue(ctx, LogRepoIDKey, tt.headerRepo) + } + + // Test the selection logic + result := plugin.getEffectiveLogRepoID(&ctx) + + if result != tt.expectedRepo { + t.Errorf("Expected repo '%s', got '%s'", tt.expectedRepo, result) + } + + shouldLog := result != "" + if shouldLog != tt.shouldLog { + t.Errorf("Expected shouldLog=%t, got shouldLog=%t", tt.shouldLog, shouldLog) + } + }) + } +} + +// TestPluginInitialization tests plugin initialization with different configs +func TestPluginInitialization(t *testing.T) { + tests := []struct { + name string + config Config + expectError bool + }{ + { + name: "Valid config with both fields", + config: Config{ + ApiKey: "test-api-key", + LogRepoId: "test-repo-id", + }, + expectError: false, + }, + { + name: "Valid config with only API key", + config: Config{ + ApiKey: "test-api-key", + LogRepoId: "", + }, + expectError: false, + }, + { + name: "Invalid config - missing API key", + config: Config{ + ApiKey: "", + LogRepoId: "test-repo-id", + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip actual Maxim SDK initialization in tests + if tt.expectError { + _, err := Init(tt.config) + if err == nil { + t.Error("Expected error but got none") + } + } else { + // For valid configs, we can't test actual initialization without real API key + // Just test the validation logic + if tt.config.ApiKey == "" { + t.Skip("Skipping valid config test - would need real Maxim API key") + } + } + }) + } +} + +// TestPluginName tests the plugin name functionality +func TestPluginName(t *testing.T) { + plugin := &Plugin{} + if plugin.GetName() != PluginName { + t.Errorf("Expected plugin name '%s', got '%s'", PluginName, plugin.GetName()) + } + if PluginName != "maxim" { + t.Errorf("Expected PluginName constant to be 'maxim', got '%s'", PluginName) + } +} diff --git a/plugins/maxim/version b/plugins/maxim/version new file mode 100644 index 000000000..95b25aee2 --- /dev/null +++ b/plugins/maxim/version @@ -0,0 +1 @@ +1.3.6 diff --git a/plugins/mocker/benchmark_test.go b/plugins/mocker/benchmark_test.go new file mode 100644 index 000000000..5e26311c7 --- /dev/null +++ b/plugins/mocker/benchmark_test.go @@ -0,0 +1,296 @@ +package mocker + +import ( + "context" + "strconv" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// BenchmarkMockerPlugin_PreHook_SimpleRule benchmarks simple rule matching +func BenchmarkMockerPlugin_PreHook_SimpleRule(b *testing.B) { + plugin, err := Init(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "simple-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Benchmark response", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, benchmark test"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_RegexRule benchmarks regex rule matching +func BenchmarkMockerPlugin_PreHook_RegexRule(b *testing.B) { + plugin, err := Init(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "regex-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + MessageRegex: bifrost.Ptr(`(?i).*hello.*`), + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Regex matched response", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, this should match the regex pattern"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_MultipleRules benchmarks multiple rule evaluation +func BenchmarkMockerPlugin_PreHook_MultipleRules(b *testing.B) { + rules := make([]MockRule, 10) + for i := 0; i < 10; i++ { + rules[i] = MockRule{ + Name: "rule-" + strconv.Itoa(i), + Enabled: true, + Priority: 100 - i, // Descending priority + Probability: 1.0, + Conditions: Conditions{ + Models: []string{"gpt-" + strconv.Itoa(i)}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Response from rule " + strconv.Itoa(i), + }, + }, + }, + } + } + + // Add a matching rule at the end + rules = append(rules, MockRule{ + Name: "matching-rule", + Enabled: true, + Priority: 50, + Probability: 1.0, + Conditions: Conditions{ + Models: []string{"gpt-4"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Matching rule response", + }, + }, + }, + }) + + plugin, err := Init(MockerConfig{ + Enabled: true, + Rules: rules, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_NoMatch benchmarks when no rules match +func BenchmarkMockerPlugin_PreHook_NoMatch(b *testing.B) { + plugin, err := Init(MockerConfig{ + Enabled: true, + DefaultBehavior: DefaultBehaviorPassthrough, + Rules: []MockRule{ + { + Name: "non-matching-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"anthropic"}, // Won't match OpenAI + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "This won't match", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, // Different from rule condition + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_Template benchmarks template processing +func BenchmarkMockerPlugin_PreHook_Template(b *testing.B) { + plugin, err := Init(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "template-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + MessageTemplate: bifrost.Ptr("Hello from {{provider}} using model {{model}}!"), + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/mocker/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/mocker/go.mod b/plugins/mocker/go.mod new file mode 100644 index 000000000..4b66ac0df --- /dev/null +++ b/plugins/mocker/go.mod @@ -0,0 +1,54 @@ +module github.com/maximhq/bifrost/plugins/mocker + +go 1.24.1 + +toolchain go1.24.3 + +require ( + github.com/jaswdr/faker/v2 v2.8.0 + github.com/maximhq/bifrost/core v1.1.37 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/mocker/go.sum b/plugins/mocker/go.sum new file mode 100644 index 000000000..1784779ff --- /dev/null +++ b/plugins/mocker/go.sum @@ -0,0 +1,127 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jaswdr/faker/v2 v2.8.0 h1:3AxdXW9U7dJmWckh/P0YgRbNlCcVsTyrUNUnLVP9b3Q= +github.com/jaswdr/faker/v2 v2.8.0/go.mod h1:jZq+qzNQr8/P+5fHd9t3txe2GNPnthrTfohtnJ7B+68= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go new file mode 100644 index 000000000..cb850af7b --- /dev/null +++ b/plugins/mocker/main.go @@ -0,0 +1,1088 @@ +package mocker + +import ( + "context" + "fmt" + "maps" + "math/rand" + "regexp" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/jaswdr/faker/v2" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + PluginName = "bifrost-mocker" +) + +// Constants for type checking and validation +const ( + // Response types + ResponseTypeSuccess = "success" + ResponseTypeError = "error" + + // Default behaviors + DefaultBehaviorPassthrough = "passthrough" + DefaultBehaviorError = "error" + DefaultBehaviorSuccess = "success" + + // Latency types + LatencyTypeFixed = "fixed" + LatencyTypeUniform = "uniform" +) + +// compiledRule represents a rule with pre-compiled regex and normalized weights for performance +type compiledRule struct { + MockRule + compiledRegex *regexp.Regexp // Pre-compiled regex for fast matching + normalizedWeights []float64 // Pre-calculated normalized weights for fast response selection +} + +// MockerPlugin provides comprehensive request/response mocking capabilities +type MockerPlugin struct { + config MockerConfig + rules []MockRule + compiledRules []compiledRule // Pre-compiled rules for performance + mu sync.RWMutex + faker faker.Faker // Use jaswdr/faker library + + // Atomic counters for high-performance statistics tracking + totalRequests int64 + mockedRequests int64 + responsesGenerated int64 + errorsGenerated int64 + + // Rule hits tracking (still needs mutex for map access) + ruleHitsMu sync.RWMutex + ruleHits map[string]int64 +} + +// MockerConfig defines the overall configuration for the mocker plugin +type MockerConfig struct { + Enabled bool `json:"enabled"` // Enable/disable the mocker plugin + GlobalLatency *Latency `json:"global_latency"` // Global latency settings applied to all rules (can be overridden per rule) + Rules []MockRule `json:"rules"` // List of mock rules to be evaluated in priority order + DefaultBehavior string `json:"default_behavior"` // Action when no rules match: "passthrough", "error", or "success" +} + +// MockRule defines a single mocking rule with conditions and responses +// Rules are evaluated in priority order (higher numbers = higher priority) +type MockRule struct { + Name string `json:"name"` // Unique rule name for identification and statistics tracking + Enabled bool `json:"enabled"` // Enable/disable this rule (disabled rules are skipped) + Priority int `json:"priority"` // Higher priority rules are checked first (higher numbers = higher priority) + Conditions Conditions `json:"conditions"` // Conditions that must match for this rule to apply + Responses []Response `json:"responses"` // Possible responses (selected using weighted random selection) + Latency *Latency `json:"latency"` // Rule-specific latency override (overrides global latency if set) + Probability float64 `json:"probability"` // Probability of rule activation (0.0=never, 1.0=always, 0=disabled) +} + +// Conditions define when a mock rule should be applied +// All specified conditions must match for the rule to trigger +type Conditions struct { + Providers []string `json:"providers"` // Match specific providers (e.g., ["openai", "anthropic"]) + Models []string `json:"models"` // Match specific models (e.g., ["gpt-4", "claude-3"]) + MessageRegex *string `json:"message_regex"` // Regex pattern to match against message content + RequestSize *SizeRange `json:"request_size"` // Request size constraints in bytes +} + +// Response defines a mock response configuration +// Either Content (for success) or Error (for error) should be set, not both +type Response struct { + Type string `json:"type"` // Response type: "success" or "error" + Weight float64 `json:"weight"` // Weight for random selection (higher = more likely) + Content *SuccessResponse `json:"content"` // Success response content (required if Type="success") + Error *ErrorResponse `json:"error"` // Error response content (required if Type="error") + AllowFallbacks *bool `json:"allow_fallbacks"` // Control fallback behavior for errors (nil=true, false=no fallbacks) +} + +// SuccessResponse defines mock success response content +// Either Message or MessageTemplate should be set (MessageTemplate takes precedence) +type SuccessResponse struct { + Message string `json:"message"` // Static response message + Model *string `json:"model"` // Override model name in response (optional) + Usage *Usage `json:"usage"` // Token usage info (optional, defaults applied if nil) + FinishReason *string `json:"finish_reason"` // Completion reason (optional, defaults to "stop") + MessageTemplate *string `json:"message_template"` // Template with variables like {{model}}, {{provider}} (overrides Message) + CustomFields map[string]interface{} `json:"custom_fields"` // Additional fields stored in response metadata +} + +// ErrorResponse defines mock error response content +type ErrorResponse struct { + Message string `json:"message"` // Error message to return + Type *string `json:"type"` // Error type (e.g., "rate_limit", "auth_error") + Code *string `json:"code"` // Error code (e.g., "429", "401") + StatusCode *int `json:"status_code"` // HTTP status code for the error +} + +// Latency defines latency simulation settings +type Latency struct { + Min time.Duration `json:"min"` // Minimum latency as time.Duration (e.g., 100*time.Millisecond, NOT raw int) + Max time.Duration `json:"max"` // Maximum latency as time.Duration (e.g., 500*time.Millisecond, NOT raw int) + Type string `json:"type"` // Latency type: "fixed" or "uniform" +} + +// SizeRange defines request size constraints in bytes +type SizeRange struct { + Min int `json:"min"` // Minimum request size in bytes + Max int `json:"max"` // Maximum request size in bytes +} + +// Usage defines token usage information +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// MockStats tracks plugin statistics and rule execution counts +type MockStats struct { + TotalRequests int64 `json:"total_requests"` // Total number of requests processed + MockedRequests int64 `json:"mocked_requests"` // Number of requests that were mocked (rules matched) + RuleHits map[string]int64 `json:"rule_hits"` // Rule name -> hit count mapping + ErrorsGenerated int64 `json:"errors_generated"` // Number of error responses generated + ResponsesGenerated int64 `json:"responses_generated"` // Number of success responses generated +} + +// Init creates a new mocker plugin instance with sensible defaults +// Returns an error if required configuration is invalid or missing +func Init(config MockerConfig) (*MockerPlugin, error) { + // Validate configuration + if err := validateConfig(config); err != nil { + return nil, fmt.Errorf("invalid mocker plugin configuration: %w", err) + } + + // Apply defaults if not set + if config.DefaultBehavior == "" { + config.DefaultBehavior = DefaultBehaviorPassthrough // Default to passthrough if no rules match + } + + // If no rules provided, create a simple catch-all rule for quick testing + if len(config.Rules) == 0 && config.Enabled { + config.Rules = []MockRule{ + { + Name: "default-mock", + Enabled: true, + Priority: 1, + Conditions: Conditions{}, // Empty conditions = match all requests + Probability: 1.0, // Always activate + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Weight: 1.0, + Content: &SuccessResponse{ + Message: "This is a mock response from the Mocker plugin", + }, + }, + }, + }, + } + } + + plugin := &MockerPlugin{ + config: config, + rules: config.Rules, + ruleHits: make(map[string]int64), + faker: faker.New(), // Initialize faker + } + + // Pre-compile all regex patterns for performance + if err := plugin.compileRules(); err != nil { + return nil, fmt.Errorf("failed to compile rules: %w", err) + } + + return plugin, nil +} + +// compileRules pre-compiles all regex patterns and calculates normalized weights for performance +func (p *MockerPlugin) compileRules() error { + p.compiledRules = make([]compiledRule, 0, len(p.rules)) + + for _, rule := range p.rules { + compiled := compiledRule{MockRule: rule} + + // Pre-compile regex if present + if rule.Conditions.MessageRegex != nil { + regex, err := regexp.Compile(*rule.Conditions.MessageRegex) + if err != nil { + return fmt.Errorf("invalid regex in rule '%s': %w", rule.Name, err) + } + compiled.compiledRegex = regex + } + + // Pre-calculate normalized weights for fast response selection + compiled.normalizedWeights = p.calculateNormalizedWeights(rule.Responses) + + p.compiledRules = append(p.compiledRules, compiled) + } + + // Sort compiled rules by priority (higher first) + p.sortCompiledRulesByPriority() + + return nil +} + +// calculateNormalizedWeights pre-calculates normalized cumulative weights for fast response selection +func (p *MockerPlugin) calculateNormalizedWeights(responses []Response) []float64 { + if len(responses) == 0 { + return nil + } + + if len(responses) == 1 { + return []float64{1.0} // Single response always gets 100% probability + } + + // Calculate total weight, applying default weight of 1.0 if not specified + totalWeight := 0.0 + for _, response := range responses { + weight := response.Weight + if weight == 0 { + weight = 1.0 // Default weight + } + totalWeight += weight + } + + // Calculate normalized cumulative weights for O(1) selection + normalizedWeights := make([]float64, len(responses)) + cumulativeWeight := 0.0 + + for i, response := range responses { + weight := response.Weight + if weight == 0 { + weight = 1.0 // Default weight + } + cumulativeWeight += weight / totalWeight // Normalize to [0, 1] + normalizedWeights[i] = cumulativeWeight + } + + // Ensure the last weight is exactly 1.0 to handle floating point precision issues + if len(normalizedWeights) > 0 { + normalizedWeights[len(normalizedWeights)-1] = 1.0 + } + + return normalizedWeights +} + +// validateConfig validates the mocker plugin configuration +func validateConfig(config MockerConfig) error { + // Validate default behavior + if config.DefaultBehavior != "" { + switch config.DefaultBehavior { + case DefaultBehaviorPassthrough, DefaultBehaviorError, DefaultBehaviorSuccess: + // Valid + default: + return fmt.Errorf("invalid default_behavior '%s', must be one of: %s, %s, %s", + config.DefaultBehavior, DefaultBehaviorPassthrough, DefaultBehaviorError, DefaultBehaviorSuccess) + } + } + + // Validate global latency if provided + if config.GlobalLatency != nil { + if err := validateLatency(*config.GlobalLatency); err != nil { + return fmt.Errorf("invalid global_latency: %w", err) + } + } + + // Validate each rule + for i, rule := range config.Rules { + if err := validateRule(rule); err != nil { + return fmt.Errorf("invalid rule at index %d (%s): %w", i, rule.Name, err) + } + } + + return nil +} + +// validateRule validates a single mock rule +func validateRule(rule MockRule) error { + // Rule name is required + if rule.Name == "" { + return fmt.Errorf("rule name is required") + } + + // Priority should be reasonable (allow negative for low priority) + if rule.Priority < -1000 || rule.Priority > 1000 { + return fmt.Errorf("priority %d is out of reasonable range (-1000 to 1000)", rule.Priority) + } + + // Probability must be between 0 and 1 + if rule.Probability < 0 || rule.Probability > 1 { + return fmt.Errorf("probability %.2f must be between 0.0 and 1.0", rule.Probability) + } + + // At least one response is required + if len(rule.Responses) == 0 { + return fmt.Errorf("at least one response is required") + } + + // Validate rule-specific latency if provided + if rule.Latency != nil { + if err := validateLatency(*rule.Latency); err != nil { + return fmt.Errorf("invalid rule latency: %w", err) + } + } + + // Validate conditions + if err := validateConditions(rule.Conditions); err != nil { + return fmt.Errorf("invalid conditions: %w", err) + } + + // Validate each response + for i, response := range rule.Responses { + if err := validateResponse(response); err != nil { + return fmt.Errorf("invalid response at index %d: %w", i, err) + } + } + + return nil +} + +// validateLatency validates latency configuration +func validateLatency(latency Latency) error { + // Type is required + if latency.Type == "" { + return fmt.Errorf("latency type is required") + } + + // Validate type + switch latency.Type { + case LatencyTypeFixed, LatencyTypeUniform: + // Valid + default: + return fmt.Errorf("invalid latency type '%s', must be one of: %s, %s", + latency.Type, LatencyTypeFixed, LatencyTypeUniform) + } + + // Min latency should be non-negative + if latency.Min < 0 { + return fmt.Errorf("minimum latency cannot be negative") + } + + // For uniform type, max should be >= min + if latency.Type == LatencyTypeUniform { + if latency.Max < latency.Min { + return fmt.Errorf("maximum latency (%v) cannot be less than minimum latency (%v)", latency.Max, latency.Min) + } + } + + return nil +} + +// validateConditions validates rule conditions +func validateConditions(conditions Conditions) error { + // Validate regex if provided + if conditions.MessageRegex != nil { + _, err := regexp.Compile(*conditions.MessageRegex) + if err != nil { + return fmt.Errorf("invalid message regex '%s': %w", *conditions.MessageRegex, err) + } + } + + // Validate request size range if provided + if conditions.RequestSize != nil { + if conditions.RequestSize.Min < 0 { + return fmt.Errorf("request size minimum cannot be negative") + } + if conditions.RequestSize.Max < conditions.RequestSize.Min { + return fmt.Errorf("request size maximum (%d) cannot be less than minimum (%d)", + conditions.RequestSize.Max, conditions.RequestSize.Min) + } + } + + return nil +} + +// validateResponse validates a response configuration +func validateResponse(response Response) error { + // Type is required + if response.Type == "" { + return fmt.Errorf("response type is required") + } + + // Validate type + switch response.Type { + case ResponseTypeSuccess, ResponseTypeError: + // Valid + default: + return fmt.Errorf("invalid response type '%s', must be one of: %s, %s", + response.Type, ResponseTypeSuccess, ResponseTypeError) + } + + // Weight should be non-negative + if response.Weight < 0 { + return fmt.Errorf("response weight cannot be negative") + } + + // Validate response content based on type + if response.Type == ResponseTypeSuccess { + if response.Content == nil { + return fmt.Errorf("success response must have content") + } + if err := validateSuccessResponse(*response.Content); err != nil { + return fmt.Errorf("invalid success content: %w", err) + } + } else if response.Type == ResponseTypeError { + if response.Error == nil { + return fmt.Errorf("error response must have error content") + } + if err := validateErrorResponse(*response.Error); err != nil { + return fmt.Errorf("invalid error content: %w", err) + } + } + + return nil +} + +// validateSuccessResponse validates success response content +func validateSuccessResponse(content SuccessResponse) error { + // Either Message or MessageTemplate must be provided + if content.Message == "" && (content.MessageTemplate == nil || *content.MessageTemplate == "") { + return fmt.Errorf("either message or message_template is required") + } + + // If usage is provided, validate it + if content.Usage != nil { + if content.Usage.PromptTokens < 0 || content.Usage.CompletionTokens < 0 || content.Usage.TotalTokens < 0 { + return fmt.Errorf("token counts cannot be negative") + } + } + + return nil +} + +// validateErrorResponse validates error response content +func validateErrorResponse(errorContent ErrorResponse) error { + // Message is required + if errorContent.Message == "" { + return fmt.Errorf("error message is required") + } + + // Status code should be reasonable if provided + if errorContent.StatusCode != nil { + if *errorContent.StatusCode < 100 || *errorContent.StatusCode > 599 { + return fmt.Errorf("status code %d is out of valid HTTP range (100-599)", *errorContent.StatusCode) + } + } + + return nil +} + +// GetName returns the plugin name +func (p *MockerPlugin) GetName() string { + return PluginName +} + +// PreHook intercepts requests and applies mocking rules based on configuration +// This is called before the actual provider request and can short-circuit the flow +func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Skip processing if plugin is disabled + if !p.config.Enabled { + return req, nil, nil + } + + // Track total request count using atomic operation (no lock needed) + atomic.AddInt64(&p.totalRequests, 1) + + // Find the first matching rule based on priority order + rule := p.findMatchingCompiledRule(req) + if rule == nil { + // No rules matched, handle according to default behavior + return p.handleDefaultBehavior(req) + } + + // Check if rule should activate based on probability (0.0 = never, 1.0 = always) + if rule.Probability > 0 && rand.Float64() > rule.Probability { + // Rule didn't activate due to probability, continue with normal flow + return req, nil, nil + } + + // Apply artificial latency simulation if configured + if latency := p.getLatency(&rule.MockRule); latency != nil { + delay := p.calculateLatency(latency) + time.Sleep(delay) + } + + // Select a response from the rule's possible responses using pre-calculated weights + response := p.selectResponse(rule) + if response == nil { + // No valid response configuration, continue with normal flow + return req, nil, nil + } + + // Update statistics using atomic operations and minimal locking + atomic.AddInt64(&p.mockedRequests, 1) + + // Rule hits still need a mutex since it's a map, but we minimize lock time + p.ruleHitsMu.Lock() + p.ruleHits[rule.Name]++ + p.ruleHitsMu.Unlock() + + // Generate appropriate mock response based on type + if response.Type == ResponseTypeSuccess { + return p.generateSuccessShortCircuit(req, response) + } else if response.Type == ResponseTypeError { + return p.generateErrorShortCircuit(req, response) + } + + // Fallback: continue with normal flow if response type is unrecognized + return req, nil, nil +} + +// PostHook processes responses after provider calls +func (p *MockerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return result, err, nil +} + +// Cleanup performs plugin cleanup and frees memory +// IMPORTANT: Call GetStats() before Cleanup() if you need the statistics, +// as this method clears all statistics data to free memory +func (p *MockerPlugin) Cleanup() error { + p.mu.Lock() + defer p.mu.Unlock() + + // Clear all statistics to free memory using atomic operations + atomic.StoreInt64(&p.totalRequests, 0) + atomic.StoreInt64(&p.mockedRequests, 0) + atomic.StoreInt64(&p.responsesGenerated, 0) + atomic.StoreInt64(&p.errorsGenerated, 0) + + // Clear rule hits map + p.ruleHitsMu.Lock() + p.ruleHits = make(map[string]int64) + p.ruleHitsMu.Unlock() + + // Clear rules to free memory + p.rules = nil + p.compiledRules = nil + + return nil +} + +// findMatchingCompiledRule finds the first rule that matches the request using pre-compiled rules +func (p *MockerPlugin) findMatchingCompiledRule(req *schemas.BifrostRequest) *compiledRule { + for i := range p.compiledRules { + rule := &p.compiledRules[i] + if !rule.Enabled { + continue + } + + if p.matchesConditionsFast(req, &rule.Conditions, rule.compiledRegex) { + return rule + } + } + return nil +} + +// matchesConditionsFast checks if request matches rule conditions with optimized performance +func (p *MockerPlugin) matchesConditionsFast(req *schemas.BifrostRequest, conditions *Conditions, compiledRegex *regexp.Regexp) bool { + // Check providers - optimized string comparison + if len(conditions.Providers) > 0 { + providerStr := string(req.Provider) + found := false + for _, provider := range conditions.Providers { + if providerStr == provider { + found = true + break + } + } + if !found { + return false + } + } + + // Check models - direct string comparison + if len(conditions.Models) > 0 { + found := false + for _, model := range conditions.Models { + if req.Model == model { + found = true + break + } + } + if !found { + return false + } + } + + // Check message regex using pre-compiled regex (major performance improvement) + if compiledRegex != nil { + // Extract message content from request (cached if possible) + messageContent := p.extractMessageContentFast(req) + if !compiledRegex.MatchString(messageContent) { + return false + } + } + + // Check request size - only calculate if needed + if conditions.RequestSize != nil { + size := p.calculateRequestSizeFast(req) + if size < conditions.RequestSize.Min || size > conditions.RequestSize.Max { + return false + } + } + + // All conditions matched + return true +} + +// extractMessageContentFast extracts message content with optimized performance +func (p *MockerPlugin) extractMessageContentFast(req *schemas.BifrostRequest) string { + // Handle text completion input + if req.Input.TextCompletionInput != nil { + return *req.Input.TextCompletionInput + } + + // Handle chat completion input - optimized for common cases + if req.Input.ChatCompletionInput != nil { + messages := *req.Input.ChatCompletionInput + if len(messages) == 0 { + return "" + } + + // Fast path for single message + if len(messages) == 1 { + if messages[0].Content.ContentStr != nil { + return *messages[0].Content.ContentStr + } + return "" + } + + // Multiple messages - use string builder for efficiency + var builder strings.Builder + for i, message := range messages { + if message.Content.ContentStr != nil { + if i > 0 { + builder.WriteByte(' ') + } + builder.WriteString(*message.Content.ContentStr) + } + } + return builder.String() + } + + return "" +} + +// calculateRequestSizeFast calculates request size with minimal overhead +func (p *MockerPlugin) calculateRequestSizeFast(req *schemas.BifrostRequest) int { + // Approximate size calculation to avoid expensive JSON marshaling + size := len(req.Model) + len(string(req.Provider)) + + // Add input size + if req.Input.TextCompletionInput != nil { + size += len(*req.Input.TextCompletionInput) + } + + if req.Input.ChatCompletionInput != nil { + for _, message := range *req.Input.ChatCompletionInput { + if message.Content.ContentStr != nil { + size += len(*message.Content.ContentStr) + } + size += 50 // Approximate overhead for message structure + } + } + + return size +} + +// generateSuccessShortCircuit creates a success response short-circuit with optimized allocations +func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, response *Response) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if response.Content == nil { + return req, nil, nil + } + + content := response.Content + message := content.Message + + // Apply message template if provided + if content.MessageTemplate != nil { + message = p.applyTemplate(*content.MessageTemplate, req) + } + + // Apply defaults for token usage if not provided + var usage schemas.LLMUsage + if content.Usage != nil { + usage = schemas.LLMUsage{ + PromptTokens: p.getOrDefault(content.Usage.PromptTokens, 10), + CompletionTokens: p.getOrDefault(content.Usage.CompletionTokens, 20), + TotalTokens: p.getOrDefault(content.Usage.TotalTokens, content.Usage.PromptTokens+content.Usage.CompletionTokens), + } + } else { + // Default usage when none specified + usage = schemas.LLMUsage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + } + } + + // Get finish reason with minimal allocation + var finishReason *string + if content.FinishReason != nil { + finishReason = content.FinishReason + } else { + // Use a static string to avoid allocation + static := "stop" + finishReason = &static + } + + // Create mock response with proper structure + mockResponse := &schemas.BifrostResponse{ + Model: req.Model, + Usage: &usage, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &message, + }, + }, + }, + FinishReason: finishReason, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: req.Provider, + }, + } + + // Override model if specified + if content.Model != nil { + mockResponse.Model = *content.Model + } + + // Only create raw response map if there are custom fields (avoid allocation) + if len(content.CustomFields) > 0 { + rawResponse := make(map[string]interface{}, len(content.CustomFields)+1) + + // Add custom fields + for key, value := range content.CustomFields { + rawResponse[key] = value + } + + // Add mock metadata + rawResponse["mock_rule"] = "success" + mockResponse.ExtraFields.RawResponse = rawResponse + } + + // Increment success response counter using atomic operation + atomic.AddInt64(&p.responsesGenerated, 1) + + return req, &schemas.PluginShortCircuit{ + Response: mockResponse, + }, nil +} + +// generateErrorShortCircuit creates an error response short-circuit with optimized performance +func (p *MockerPlugin) generateErrorShortCircuit(req *schemas.BifrostRequest, response *Response) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if response.Error == nil { + return req, nil, nil + } + + errorContent := response.Error + allowFallbacks := response.AllowFallbacks + + // Create mock error + mockError := &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: errorContent.Message, + }, + AllowFallbacks: allowFallbacks, + } + + // Set error type + if errorContent.Type != nil { + mockError.Error.Type = errorContent.Type + } + + // Set error code + if errorContent.Code != nil { + mockError.Error.Code = errorContent.Code + } + + // Set status code + if errorContent.StatusCode != nil { + mockError.StatusCode = errorContent.StatusCode + } + + // Increment error counter using atomic operation + atomic.AddInt64(&p.errorsGenerated, 1) + + return req, &schemas.PluginShortCircuit{ + Error: mockError, + }, nil +} + +// selectResponse selects a response using pre-calculated normalized weights for optimal performance +func (p *MockerPlugin) selectResponse(rule *compiledRule) *Response { + responses := rule.Responses + normalizedWeights := rule.normalizedWeights + + if len(responses) == 0 { + return nil + } + + if len(responses) == 1 { + return &responses[0] + } + + // Fast O(log n) binary search using pre-calculated cumulative weights + randomValue := rand.Float64() + + // Binary search for the selected response + left, right := 0, len(normalizedWeights)-1 + for left < right { + mid := (left + right) / 2 + if randomValue <= normalizedWeights[mid] { + right = mid + } else { + left = mid + 1 + } + } + + return &responses[left] +} + +// getLatency returns the applicable latency configuration +func (p *MockerPlugin) getLatency(rule *MockRule) *Latency { + if rule.Latency != nil { + return rule.Latency + } + return p.config.GlobalLatency +} + +// calculateLatency calculates the actual delay based on latency configuration +func (p *MockerPlugin) calculateLatency(latency *Latency) time.Duration { + switch latency.Type { + case LatencyTypeFixed: + return latency.Min + case LatencyTypeUniform: + if latency.Max <= latency.Min { + return latency.Min + } + // Calculate random duration between Min and Max + diff := latency.Max - latency.Min + return latency.Min + time.Duration(rand.Float64()*float64(diff)) + default: + // Default to fixed latency + return latency.Min + } +} + +// handleDefaultBehavior handles requests when no rules match +func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + switch p.config.DefaultBehavior { + case DefaultBehaviorError: + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: "Mock plugin default error", + }, + }, + }, nil + case DefaultBehaviorSuccess: + finishReason := "stop" + return req, &schemas.PluginShortCircuit{ + Response: &schemas.BifrostResponse{ + Model: req.Model, + Usage: &schemas.LLMUsage{ + PromptTokens: 5, + CompletionTokens: 10, + TotalTokens: 15, + }, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Mock plugin default response"), + }, + }, + }, + FinishReason: &finishReason, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: req.Provider, + }, + }, + }, nil + default: // DefaultBehaviorPassthrough + return req, nil, nil + } +} + +// Helper functions + +// sortCompiledRulesByPriority sorts rules by priority (descending) +func (p *MockerPlugin) sortCompiledRulesByPriority() { + sort.Slice(p.compiledRules, func(i, j int) bool { + return p.compiledRules[i].Priority > p.compiledRules[j].Priority + }) +} + +// applyTemplate applies template variables with optimized string operations including faker support +func (p *MockerPlugin) applyTemplate(template string, req *schemas.BifrostRequest) string { + // Fast path: no template variables + if !strings.Contains(template, "{{") { + return template + } + + result := template + + // Replace basic variables first + replacer := strings.NewReplacer( + "{{provider}}", string(req.Provider), + "{{model}}", req.Model, + ) + result = replacer.Replace(result) + + // Handle faker variables with regex for more complex patterns + fakerRegex := regexp.MustCompile(`\{\{faker\.([^}]+)\}\}`) + result = fakerRegex.ReplaceAllStringFunc(result, func(match string) string { + // Extract the faker method name + submatch := fakerRegex.FindStringSubmatch(match) + if len(submatch) < 2 { + return match // Return original if no match + } + + fakerMethod := submatch[1] + return p.generateFakerValue(fakerMethod) + }) + + return result +} + +// generateFakerValue generates fake data based on the faker method name +func (p *MockerPlugin) generateFakerValue(method string) string { + // Parse method with potential parameters (e.g., "lorem_ipsum:20" for 20 words) + parts := strings.Split(method, ":") + baseMethod := parts[0] + + switch baseMethod { + case "name": + return p.faker.Person().Name() + case "first_name": + return p.faker.Person().FirstName() + case "last_name": + return p.faker.Person().LastName() + case "email": + return p.faker.Internet().Email() + case "phone": + return p.faker.Phone().Number() + case "address": + return p.faker.Address().Address() + case "city": + return p.faker.Address().City() + case "state": + return p.faker.Address().State() + case "zip_code": + return p.faker.Address().PostCode() + case "company": + return p.faker.Company().Name() + case "job_title": + return p.faker.Company().JobTitle() + case "lorem_ipsum": + wordCount := 10 // default + if len(parts) > 1 { + if count, err := fmt.Sscanf(parts[1], "%d", &wordCount); err != nil || count != 1 { + wordCount = 10 + } + } + return p.faker.Lorem().Sentence(wordCount) + case "uuid": + return p.faker.UUID().V4() + case "hex_color": + return p.faker.Color().Hex() + case "integer": + min, max := 1, 100 // defaults + if len(parts) > 1 { + params := strings.Split(parts[1], ",") + if len(params) >= 2 { + if _, err := fmt.Sscanf(params[0], "%d", &min); err != nil { + min = 1 // fallback to default on parse error + } + if _, err := fmt.Sscanf(params[1], "%d", &max); err != nil { + max = 100 // fallback to default on parse error + } + } + } + return fmt.Sprintf("%d", p.faker.IntBetween(min, max)) + case "float": + min, max := 0, 100 // defaults as integers + if len(parts) > 1 { + params := strings.Split(parts[1], ",") + if len(params) >= 2 { + if _, err := fmt.Sscanf(params[0], "%d", &min); err != nil { + min = 0 // fallback to default on parse error + } + if _, err := fmt.Sscanf(params[1], "%d", &max); err != nil { + max = 100 // fallback to default on parse error + } + } + } + return fmt.Sprintf("%.2f", p.faker.Float64(2, min, max)) + case "boolean": + return fmt.Sprintf("%t", p.faker.Bool()) + case "date": + return p.faker.Time().Time(time.Now()).Format("2006-01-02") + case "datetime": + return p.faker.Time().Time(time.Now()).Format("2006-01-02 15:04:05") + case "word": + return p.faker.Lorem().Word() + case "sentence": + wordCount := 8 // default + if len(parts) > 1 { + if count, err := fmt.Sscanf(parts[1], "%d", &wordCount); err != nil || count != 1 { + wordCount = 8 + } + } + return p.faker.Lorem().Sentence(wordCount) + default: + // Return the original placeholder if method is not recognized + return fmt.Sprintf("{{faker.%s}}", method) + } +} + +// getOrDefault returns value or default if 0 +func (p *MockerPlugin) getOrDefault(value, defaultValue int) int { + if value == 0 { + return defaultValue + } + return value +} + +// GetStats returns current plugin statistics +// IMPORTANT: Call this method before Cleanup() if you need the statistics, +// as Cleanup() clears all statistics data to free memory +func (p *MockerPlugin) GetStats() MockStats { + p.mu.RLock() + defer p.mu.RUnlock() + + // Create a deep copy using atomic reads for counters + statsCopy := MockStats{ + TotalRequests: atomic.LoadInt64(&p.totalRequests), + MockedRequests: atomic.LoadInt64(&p.mockedRequests), + ErrorsGenerated: atomic.LoadInt64(&p.errorsGenerated), + ResponsesGenerated: atomic.LoadInt64(&p.responsesGenerated), + RuleHits: make(map[string]int64), + } + + // Copy rule hits map (still needs lock) + p.ruleHitsMu.RLock() + maps.Copy(statsCopy.RuleHits, p.ruleHits) + p.ruleHitsMu.RUnlock() + + return statsCopy +} diff --git a/plugins/mocker/plugin_test.go b/plugins/mocker/plugin_test.go new file mode 100644 index 000000000..820019ae6 --- /dev/null +++ b/plugins/mocker/plugin_test.go @@ -0,0 +1,544 @@ +package mocker + +import ( + "context" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the Mocker plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic}, nil +} + +// GetKeysForProvider returns a dummy API key configuration for testing. +// Since we're testing the mocker plugin, these keys should never be used +// as the plugin intercepts requests before they reach the actual providers. +func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: "dummy-api-key-for-testing", // Dummy key + Models: []string{"gpt-4", "gpt-4-turbo", "claude-3"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestMockerPlugin_GetName tests the plugin name +func TestMockerPlugin_GetName(t *testing.T) { + plugin, err := Init(MockerConfig{}) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + if plugin.GetName() != PluginName { + t.Errorf("Expected '%s', got '%s'", PluginName, plugin.GetName()) + } +} + +// TestMockerPlugin_Disabled tests that disabled plugin doesn't interfere +func TestMockerPlugin_Disabled(t *testing.T) { + ctx := context.Background() + config := MockerConfig{ + Enabled: false, + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + // This should pass through to the real provider (but will fail due to dummy key) + _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + // Should get an authentication error from OpenAI, not a mock response + // This proves the plugin is disabled and not intercepting requests + if bifrostErr == nil { + t.Error("Expected error from real provider with dummy API key") + } +} + +// TestMockerPlugin_DefaultMockRule tests the default catch-all rule +func TestMockerPlugin_DefaultMockRule(t *testing.T) { + ctx := context.Background() + config := MockerConfig{ + Enabled: true, // No rules provided, should create default rule + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + if *response.Choices[0].Message.Content.ContentStr != "This is a mock response from the Mocker plugin" { + t.Errorf("Expected default mock message, got: %s", *response.Choices[0].Message.Content.ContentStr) + } +} + +// TestMockerPlugin_CustomSuccessRule tests custom success response +func TestMockerPlugin_CustomSuccessRule(t *testing.T) { + ctx := context.Background() + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "openai-success", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Custom OpenAI mock response", + Usage: &Usage{ + PromptTokens: 15, + CompletionTokens: 25, + TotalTokens: 40, + }, + }, + }, + }, + }, + }, + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + if *response.Choices[0].Message.Content.ContentStr != "Custom OpenAI mock response" { + t.Errorf("Expected custom message, got: %s", *response.Choices[0].Message.Content.ContentStr) + } + if response.Usage.TotalTokens != 40 { + t.Errorf("Expected 40 total tokens, got %d", response.Usage.TotalTokens) + } +} + +// TestMockerPlugin_ErrorResponse tests error response generation +func TestMockerPlugin_ErrorResponse(t *testing.T) { + ctx := context.Background() + allowFallbacks := false + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "rate-limit-error", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeError, + AllowFallbacks: &allowFallbacks, + Error: &ErrorResponse{ + Message: "Rate limit exceeded", + Type: bifrost.Ptr("rate_limit"), + Code: bifrost.Ptr("429"), + StatusCode: bifrost.Ptr(429), + }, + }, + }, + }, + }, + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr == nil { + t.Fatal("Expected error response") + } + if bifrostErr.Error.Message != "Rate limit exceeded" { + t.Errorf("Expected 'Rate limit exceeded', got: %s", bifrostErr.Error.Message) + } + if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 429 { + t.Errorf("Expected status code 429, got: %v", bifrostErr.StatusCode) + } +} + +// TestMockerPlugin_MessageTemplate tests template variable substitution +func TestMockerPlugin_MessageTemplate(t *testing.T) { + ctx := context.Background() + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "template-test", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + MessageTemplate: bifrost.Ptr("Hello from {{provider}} using model {{model}}"), + }, + }, + }, + }, + }, + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.Anthropic, + Model: "claude-3", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + expectedMessage := "Hello from anthropic using model claude-3" + if *response.Choices[0].Message.Content.ContentStr != expectedMessage { + t.Errorf("Expected '%s', got: %s", expectedMessage, *response.Choices[0].Message.Content.ContentStr) + } +} + +// TestMockerPlugin_Statistics tests plugin statistics tracking +func TestMockerPlugin_Statistics(t *testing.T) { + ctx := context.Background() + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "stats-test", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Stats test response", + }, + }, + }, + }, + }, + } + plugin, err := Init(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Shutdown() + + // Make multiple requests + for i := 0; i < 3; i++ { + _, _ = client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + } + + // Check statistics + stats := plugin.GetStats() + if stats.TotalRequests != 3 { + t.Errorf("Expected 3 total requests, got %d", stats.TotalRequests) + } + if stats.MockedRequests != 3 { + t.Errorf("Expected 3 mocked requests, got %d", stats.MockedRequests) + } + if stats.ResponsesGenerated != 3 { + t.Errorf("Expected 3 responses generated, got %d", stats.ResponsesGenerated) + } + if stats.RuleHits["stats-test"] != 3 { + t.Errorf("Expected 3 hits for 'stats-test' rule, got %d", stats.RuleHits["stats-test"]) + } +} + +// TestMockerPlugin_ValidationErrors tests configuration validation +func TestMockerPlugin_ValidationErrors(t *testing.T) { + tests := []struct { + name string + config MockerConfig + expectError bool + }{ + { + name: "invalid default behavior", + config: MockerConfig{ + Enabled: true, + DefaultBehavior: "invalid", + }, + expectError: true, + }, + { + name: "missing rule name", + config: MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "", // Missing name + Enabled: true, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "test", + }, + }, + }, + }, + }, + }, + expectError: true, + }, + { + name: "invalid probability", + config: MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "test", + Enabled: true, + Probability: 1.5, // Invalid probability > 1 + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "test", + }, + }, + }, + }, + }, + }, + expectError: true, + }, + { + name: "valid configuration", + config: MockerConfig{ + Enabled: true, + DefaultBehavior: DefaultBehaviorPassthrough, + Rules: []MockRule{ + { + Name: "valid-rule", + Enabled: true, + Probability: 0.5, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Valid response", + }, + }, + }, + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Init(tt.config) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + }) + } +} diff --git a/plugins/mocker/version b/plugins/mocker/version new file mode 100644 index 000000000..fd9d1a5ac --- /dev/null +++ b/plugins/mocker/version @@ -0,0 +1 @@ +1.2.14 diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/semanticcache/go.mod b/plugins/semanticcache/go.mod new file mode 100644 index 000000000..72b4ea973 --- /dev/null +++ b/plugins/semanticcache/go.mod @@ -0,0 +1,83 @@ +module github.com/maximhq/bifrost/plugins/semanticcache + +go 1.24 + +toolchain go1.24.3 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 + github.com/google/uuid v1.6.0 + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/bifrost/framework v1.0.23 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/redis/go-redis/v9 v9.12.1 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/weaviate/weaviate v1.31.5 // indirect + github.com/weaviate/weaviate-go-client/v5 v5.2.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/semanticcache/go.sum b/plugins/semanticcache/go.sum new file mode 100644 index 000000000..3453ca6da --- /dev/null +++ b/plugins/semanticcache/go.sum @@ -0,0 +1,345 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/bifrost/framework v1.0.23 h1:erRPP9Q0WIaUgxuLBN8urd77SObEF9irPvpV9Wbegyk= +github.com/maximhq/bifrost/framework v1.0.23/go.mod h1:uEB0iuQtFfuFuMrhccMsb+51mf8m8X2tB8ZlDVoJUbM= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go new file mode 100644 index 000000000..8e03fa986 --- /dev/null +++ b/plugins/semanticcache/main.go @@ -0,0 +1,716 @@ +// Package semanticcache provides semantic caching integration for Bifrost plugin. +// This plugin caches responses using both direct hash matching (xxhash) and semantic similarity search (embeddings). +// It supports configurable caching behavior via the VectorStore abstraction, with TTL management and streaming response handling. +package semanticcache + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework" + "github.com/maximhq/bifrost/framework/vectorstore" +) + +// Config contains configuration for the semantic cache plugin. +// The VectorStore abstraction handles the underlying storage implementation and its defaults. +// Only specify values you want to override from the semantic cache defaults. +type Config struct { + // Embedding Model settings - REQUIRED for semantic caching + Provider schemas.ModelProvider `json:"provider"` + Keys []schemas.Key `json:"keys"` + EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional) + + // Plugin behavior settings + CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` // Clean up cache on shutdown (default: false) + TTL time.Duration `json:"ttl,omitempty"` // Time-to-live for cached responses (default: 5min) + Threshold float64 `json:"threshold,omitempty"` // Cosine similarity threshold for semantic matching (default: 0.8) + VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` // Namespace for vector store (optional) + Dimension int `json:"dimension"` // Dimension for vector store + + // Advanced caching behavior + ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` // Skip caching for requests with more than this number of messages in the conversation history (default: 3) + CacheByModel *bool `json:"cache_by_model,omitempty"` // Include model in cache key (default: true) + CacheByProvider *bool `json:"cache_by_provider,omitempty"` // Include provider in cache key (default: true) + ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` // Exclude system prompt in cache key (default: false) +} + +// UnmarshalJSON implements custom JSON unmarshaling for semantic cache Config. +// It supports TTL parsing from both string durations ("1m", "1hr") and numeric seconds for configurable cache behavior. +func (c *Config) UnmarshalJSON(data []byte) error { + // Define a temporary struct to avoid infinite recursion + type TempConfig struct { + Provider string `json:"provider"` + Keys []schemas.Key `json:"keys"` + EmbeddingModel string `json:"embedding_model,omitempty"` + CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` + Dimension int `json:"dimension"` + TTL interface{} `json:"ttl,omitempty"` + Threshold float64 `json:"threshold,omitempty"` + VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` + ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` + CacheByModel *bool `json:"cache_by_model,omitempty"` + CacheByProvider *bool `json:"cache_by_provider,omitempty"` + ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` + } + + var temp TempConfig + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Set simple fields + c.Provider = schemas.ModelProvider(temp.Provider) + c.Keys = temp.Keys + c.EmbeddingModel = temp.EmbeddingModel + c.CleanUpOnShutdown = temp.CleanUpOnShutdown + c.Dimension = temp.Dimension + c.CacheByModel = temp.CacheByModel + c.CacheByProvider = temp.CacheByProvider + c.VectorStoreNamespace = temp.VectorStoreNamespace + c.ConversationHistoryThreshold = temp.ConversationHistoryThreshold + c.Threshold = temp.Threshold + c.ExcludeSystemPrompt = temp.ExcludeSystemPrompt + // Handle TTL field with custom parsing for VectorStore-backed cache behavior + if temp.TTL != nil { + switch v := temp.TTL.(type) { + case string: + // Try parsing as duration string (e.g., "1m", "1hr") for semantic cache TTL + duration, err := time.ParseDuration(v) + if err != nil { + return fmt.Errorf("failed to parse TTL duration string '%s': %w", v, err) + } + c.TTL = duration + case int: + // Handle integer seconds for semantic cache TTL + c.TTL = time.Duration(v) * time.Second + default: + // Try converting to string and parsing as number for semantic cache TTL + ttlStr := fmt.Sprintf("%v", v) + if seconds, err := strconv.ParseFloat(ttlStr, 64); err == nil { + c.TTL = time.Duration(seconds * float64(time.Second)) + } else { + return fmt.Errorf("unsupported TTL type: %T (value: %v)", v, v) + } + } + } + + return nil +} + +// StreamChunk represents a single chunk from a streaming response +type StreamChunk struct { + Timestamp time.Time // When chunk was received + Response *schemas.BifrostResponse // The actual response chunk + FinishReason *string // If this is the final chunk +} + +// StreamAccumulator manages accumulation of streaming chunks for caching +type StreamAccumulator struct { + RequestID string // The request ID + Chunks []*StreamChunk // All chunks for this stream + IsComplete bool // Whether the stream is complete + HasError bool // Whether any chunk in the stream had an error + FinalTimestamp time.Time // When the stream completed + Embedding []float32 // Embedding for the original request + Metadata map[string]interface{} // Metadata for caching + TTL time.Duration // TTL for this cache entry + mu sync.Mutex // Protects chunk operations +} + +// Plugin implements the schemas.Plugin interface for semantic caching. +// It caches responses using a two-tier approach: direct hash matching for exact requests +// and semantic similarity search for related content. The plugin supports configurable caching behavior +// via the VectorStore abstraction, including TTL management and streaming response handling. +// +// Fields: +// - store: VectorStore instance for semantic cache operations +// - config: Plugin configuration including semantic cache and caching settings +// - logger: Logger instance for plugin operations +type Plugin struct { + store vectorstore.VectorStore + config Config + logger schemas.Logger + client *bifrost.Bifrost + streamAccumulators sync.Map // Track stream accumulators by request ID +} + +// Plugin constants +const ( + PluginName string = "semantic_cache" + DefaultVectorStoreNamespace string = "BifrostSemanticCachePlugin" + PluginLoggerPrefix string = "[Semantic Cache]" + CacheConnectionTimeout time.Duration = 5 * time.Second + CreateNamespaceTimeout time.Duration = 30 * time.Second + CacheSetTimeout time.Duration = 30 * time.Second + DefaultCacheTTL time.Duration = 5 * time.Minute + DefaultCacheThreshold float64 = 0.8 + DefaultConversationHistoryThreshold int = 3 +) + +var SelectFields = []string{"request_hash", "response", "stream_chunks", "expires_at", "cache_key", "provider", "model"} + +var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{ + "request_hash": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The hash of the request", + }, + "response": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The response from the provider", + }, + "stream_chunks": { + DataType: vectorstore.VectorStorePropertyTypeStringArray, + Description: "The stream chunks from the provider", + }, + "expires_at": { + DataType: vectorstore.VectorStorePropertyTypeInteger, + Description: "The expiration time of the cache entry", + }, + "cache_key": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The cache key from the request", + }, + "provider": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The provider used for the request", + }, + "model": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The model used for the request", + }, + "params_hash": { + DataType: vectorstore.VectorStorePropertyTypeString, + Description: "The hash of the parameters used for the request", + }, + "from_bifrost_semantic_cache_plugin": { + DataType: vectorstore.VectorStorePropertyTypeBoolean, + Description: "Whether the cache entry was created by the BifrostSemanticCachePlugin", + }, +} + +type PluginAccount struct { + provider schemas.ModelProvider + keys []schemas.Key +} + +func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{pa.provider}, nil +} + +func (pa *PluginAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return pa.keys, nil +} + +func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// Dependencies is a list of dependencies that the plugin requires. +var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore} + +// ContextKey is a custom type for context keys to prevent key collisions +type ContextKey string + +const ( + CacheKey ContextKey = "semantic_cache_key" // To set the cache key for a request - REQUIRED for all requests + CacheTTLKey ContextKey = "semantic_cache_ttl" // To explicitly set the TTL for a request + CacheThresholdKey ContextKey = "semantic_cache_threshold" // To explicitly set the threshold for a request + CacheTypeKey ContextKey = "semantic_cache_cache_type" // To explicitly set the cache type for a request + CacheNoStoreKey ContextKey = "semantic_cache_no_store" // To explicitly disable storing the response in the cache + + // context keys for internal usage + requestIDKey ContextKey = "semantic_cache_request_id" + requestHashKey ContextKey = "semantic_cache_request_hash" + requestEmbeddingKey ContextKey = "semantic_cache_embedding" + requestEmbeddingTokensKey ContextKey = "semantic_cache_embedding_tokens" + requestParamsHashKey ContextKey = "semantic_cache_params_hash" + requestModelKey ContextKey = "semantic_cache_model" + requestProviderKey ContextKey = "semantic_cache_provider" + isCacheHitKey ContextKey = "semantic_cache_is_cache_hit" + cacheHitTypeKey ContextKey = "semantic_cache_cache_hit_type" +) + +type CacheType string + +const ( + CacheTypeDirect CacheType = "direct" + CacheTypeSemantic CacheType = "semantic" +) + +// Init creates a new semantic cache plugin instance with the provided configuration. +// It uses the VectorStore abstraction for cache operations and returns a configured plugin. +// +// The VectorStore handles the underlying storage implementation and its defaults. +// The plugin only sets defaults for its own behavior (TTL, cache key generation, etc.). +// +// Parameters: +// - config: Semantic cache and plugin configuration (CacheKey is required) +// - logger: Logger instance for the plugin +// - store: VectorStore instance for cache operations +// +// Returns: +// - schemas.Plugin: A configured semantic cache plugin instance +// - error: Any error that occurred during plugin initialization +func Init(ctx context.Context, config Config, logger schemas.Logger, store vectorstore.VectorStore) (schemas.Plugin, error) { + // Set plugin-specific defaults + if config.VectorStoreNamespace == "" { + logger.Debug(PluginLoggerPrefix + " Vector store namespace is not set, using default of " + DefaultVectorStoreNamespace) + config.VectorStoreNamespace = DefaultVectorStoreNamespace + } + if config.TTL == 0 { + logger.Debug(PluginLoggerPrefix + " TTL is not set, using default of 5 minutes") + config.TTL = DefaultCacheTTL + } + if config.Threshold == 0 { + logger.Debug(PluginLoggerPrefix + " Threshold is not set, using default of " + strconv.FormatFloat(DefaultCacheThreshold, 'f', -1, 64)) + config.Threshold = DefaultCacheThreshold + } + if config.ConversationHistoryThreshold == 0 { + logger.Debug(PluginLoggerPrefix + " Conversation history threshold is not set, using default of " + strconv.Itoa(DefaultConversationHistoryThreshold)) + config.ConversationHistoryThreshold = DefaultConversationHistoryThreshold + } + + // Set cache behavior defaults + if config.CacheByModel == nil { + config.CacheByModel = bifrost.Ptr(true) + } + if config.CacheByProvider == nil { + config.CacheByProvider = bifrost.Ptr(true) + } + + plugin := &Plugin{ + store: store, + config: config, + logger: logger, + } + + if config.Provider == "" || len(config.Keys) == 0 { + logger.Warn(PluginLoggerPrefix + " Provider and keys are required for semantic cache, falling back to direct search only") + } else { + bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Logger: logger, + Account: &PluginAccount{ + provider: config.Provider, + keys: config.Keys, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err) + } + + plugin.client = bifrost + } + + createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout) + defer cancel() + if err := store.CreateNamespace(createCtx, config.VectorStoreNamespace, config.Dimension, VectorStoreProperties); err != nil { + return nil, fmt.Errorf("failed to create namespace for semantic cache: %w", err) + } + + return plugin, nil +} + +// GetName returns the canonical name of the semantic cache plugin. +// This name is used for plugin identification and logging purposes. +// +// Returns: +// - string: The plugin name for semantic cache +func (plugin *Plugin) GetName() string { + return PluginName +} + +// PreHook is called before a request is processed by Bifrost. +// It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. +// Uses UUID-based keys for entries stored in the VectorStore. +// +// Parameters: +// - ctx: Pointer to the context.Context +// - req: The incoming Bifrost request +// +// Returns: +// - *schemas.BifrostRequest: The original request +// - *schemas.BifrostResponse: Cached response if found, nil otherwise +// - error: Any error that occurred during cache lookup +func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Get the cache key from the context + var cacheKey string + var ok bool + + cacheKey, ok = (*ctx).Value(CacheKey).(string) + if !ok || cacheKey == "" { + plugin.logger.Debug(PluginLoggerPrefix + " No cache key found in context, continuing without caching") + return req, nil, nil + } + + if plugin.isConversationHistoryThresholdExceeded(req) { + plugin.logger.Debug(PluginLoggerPrefix + " Skipping caching for request with conversation history threshold exceeded") + return req, nil, nil + } + + // Generate UUID for this request + requestID := uuid.New().String() + + // Store request ID, model, and provider in context for PostHook + *ctx = context.WithValue(*ctx, requestIDKey, requestID) + *ctx = context.WithValue(*ctx, requestModelKey, req.Model) + *ctx = context.WithValue(*ctx, requestProviderKey, req.Provider) + + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + return req, nil, nil + } + + performDirectSearch, performSemanticSearch := true, true + if (*ctx).Value(CacheTypeKey) != nil { + cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Cache type is not a CacheType, using all available cache types") + } else { + performDirectSearch = cacheTypeVal == CacheTypeDirect + performSemanticSearch = cacheTypeVal == CacheTypeSemantic + } + } + + if performDirectSearch { + shortCircuit, err := plugin.performDirectSearch(ctx, req, requestType, cacheKey) + if err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Direct search failed: " + err.Error()) + // Don't return - continue to semantic search fallback + shortCircuit = nil // Ensure we don't use an invalid shortCircuit + } + + if shortCircuit != nil { + return req, shortCircuit, nil + } + } + + if performSemanticSearch && plugin.client != nil { + if req.Input.EmbeddingInput != nil || req.Input.TranscriptionInput != nil { + plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") + return req, nil, nil + } + + // Try semantic search as fallback + shortCircuit, err := plugin.performSemanticSearch(ctx, req, requestType, cacheKey) + if err != nil { + return req, nil, nil + } + + if shortCircuit != nil { + return req, shortCircuit, nil + } + } + + return req, nil, nil +} + +// PostHook is called after a response is received from a provider. +// It caches responses in the VectorStore using UUID-based keys with unified metadata structure +// including provider, model, request hash, and TTL. Handles both single and streaming responses. +// +// The function performs the following operations: +// 1. Checks configurable caching behavior and skips caching for unsuccessful responses if configured +// 2. Retrieves the request hash and ID from the context (set during PreHook) +// 3. Marshals the response for storage +// 4. Stores the unified cache entry in the VectorStore asynchronously (non-blocking) +// +// The VectorStore Add operation runs in a separate goroutine to avoid blocking the response. +// The function gracefully handles errors and continues without caching if any step fails, +// ensuring that response processing is never interrupted by caching issues. +// +// Parameters: +// - ctx: Pointer to the context.Context containing the request hash and ID +// - res: The response from the provider to be cached +// - bifrostErr: The error from the provider, if any (used for success determination) +// +// Returns: +// - *schemas.BifrostResponse: The original response, unmodified +// - *schemas.BifrostError: The original error, unmodified +// - error: Any error that occurred during caching preparation (always nil as errors are handled gracefully) +func (plugin *Plugin) PostHook(ctx *context.Context, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if bifrostErr != nil { + return res, bifrostErr, nil + } + + isCacheHit := (*ctx).Value(isCacheHitKey) + if isCacheHit != nil { + isCacheHitValue, ok := isCacheHit.(bool) + if ok && isCacheHitValue { + return res, nil, nil + } + } + + // Check if caching is explicitly disabled + noStore := (*ctx).Value(CacheNoStoreKey) + if noStore != nil { + noStoreValue, ok := noStore.(bool) + if ok && noStoreValue { + plugin.logger.Debug(PluginLoggerPrefix + " Caching is explicitly disabled for this request, continuing without caching") + return res, nil, nil + } + } + + // Get the request type from context + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + return res, nil, nil + } + + // Get the cache key from context + cacheKey, ok := (*ctx).Value(CacheKey).(string) + if !ok { + return res, nil, nil + } + + // Get the request ID from context + requestID, ok := (*ctx).Value(requestIDKey).(string) + if !ok { + return res, nil, nil + } + + // Get the hash from context + hash, ok := (*ctx).Value(requestHashKey).(string) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Hash is not a string, continuing without caching") + return res, nil, nil + } + + // Check cache type to optimize embedding handling + var embedding []float32 + var shouldStoreEmbeddings = true + + if (*ctx).Value(CacheTypeKey) != nil { + cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) + if ok && cacheTypeVal == CacheTypeDirect { + // For direct-only caching, skip embedding operations entirely + shouldStoreEmbeddings = false + plugin.logger.Debug(PluginLoggerPrefix + " Skipping embedding operations for direct-only cache type") + } + } + + // Get embedding from context if available and needed + if shouldStoreEmbeddings && requestType != schemas.EmbeddingRequest && requestType != schemas.TranscriptionRequest { + embeddingValue := (*ctx).Value(requestEmbeddingKey) + if embeddingValue != nil { + embedding, ok = embeddingValue.([]float32) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Embedding is not a []float32, continuing without caching") + return res, nil, nil + } + } + // Note: embedding can be nil for direct cache hits or when semantic search is disabled + // This is fine - we can still cache using direct hash matching + } + + // Get the provider from context + provider, ok := (*ctx).Value(requestProviderKey).(schemas.ModelProvider) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Provider is not a schemas.ModelProvider, continuing without caching") + return res, nil, nil + } + + // Get the model from context + model, ok := (*ctx).Value(requestModelKey).(string) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Model is not a string, continuing without caching") + return res, nil, nil + } + + isFinalChunk := bifrost.IsFinalChunk(ctx) + + // Get the input tokens from context (can be nil if not set) + inputTokens, ok := (*ctx).Value(requestEmbeddingTokensKey).(int) + if ok { + isStreamRequest := bifrost.IsStreamRequestType(requestType) + + if !isStreamRequest || (isStreamRequest && isFinalChunk) { + if res.ExtraFields.CacheDebug == nil { + res.ExtraFields.CacheDebug = &schemas.BifrostCacheDebug{} + } + res.ExtraFields.CacheDebug.CacheHit = false + res.ExtraFields.CacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider)) + res.ExtraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel) + res.ExtraFields.CacheDebug.InputTokens = &inputTokens + } + } + + cacheTTL := plugin.config.TTL + + ttlValue := (*ctx).Value(CacheTTLKey) + if ttlValue != nil { + // Get the request TTL from the context + ttl, ok := ttlValue.(time.Duration) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " TTL is not a time.Duration, using default TTL") + } else { + cacheTTL = ttl + } + } + + // Cache everything in a unified VectorEntry asynchronously to avoid blocking the response + go func() { + // Create a background context with timeout for the cache operation + cacheCtx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) + defer cancel() + + // Get metadata from context + paramsHash, _ := (*ctx).Value(requestParamsHashKey).(string) + + // Build unified metadata with provider, model, and all params + unifiedMetadata := plugin.buildUnifiedMetadata(provider, model, paramsHash, hash, cacheKey, cacheTTL) + + // Handle streaming vs non-streaming responses + // Pass nil for embedding if we're in direct-only mode to optimize storage + embeddingToStore := embedding + if !shouldStoreEmbeddings { + embeddingToStore = nil + } + + if plugin.isStreamingRequest(requestType) { + if err := plugin.addStreamingResponse(cacheCtx, requestID, res, bifrostErr, embeddingToStore, unifiedMetadata, cacheTTL, isFinalChunk); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to cache streaming response: %v", PluginLoggerPrefix, err)) + } + } else { + if err := plugin.addSingleResponse(cacheCtx, requestID, res, embeddingToStore, unifiedMetadata, cacheTTL); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to cache single response: %v", PluginLoggerPrefix, err)) + } + } + }() + + return res, nil, nil +} + +// Cleanup performs cleanup operations for the semantic cache plugin. +// It removes all cached entries created by this plugin from the VectorStore only if CleanUpOnShutdown is true. +// Identifies cache entries by the presence of semantic cache-specific fields (request_hash, cache_key). +// +// The function performs the following operations: +// 1. Checks if cleanup is enabled via CleanUpOnShutdown config +// 2. Retrieves all entries and filters client-side to identify cache entries +// 3. Deletes all matching cache entries from the VectorStore in batches +// +// This method should be called when shutting down the application to ensure +// proper resource cleanup if configured to do so. +// +// Returns: +// - error: Any error that occurred during cleanup operations +func (plugin *Plugin) Cleanup() error { + // Clean up old stream accumulators first + plugin.cleanupOldStreamAccumulators() + + // Only clean up cache entries if configured to do so + if !plugin.config.CleanUpOnShutdown { + plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup") + return nil + } + + // Clean up all cache entries created by this plugin + ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) + defer cancel() + + plugin.logger.Debug(PluginLoggerPrefix + " Starting cleanup of cache entries...") + + // Delete all cache entries created by this plugin + queries := []vectorstore.Query{ + { + Field: "from_bifrost_semantic_cache_plugin", + Operator: vectorstore.QueryOperatorEqual, + Value: true, + }, + } + + results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries) + if err != nil { + return fmt.Errorf("failed to delete cache entries: %w", err) + } + + for _, result := range results { + if result.Status == vectorstore.DeleteStatusError { + plugin.logger.Warn(fmt.Sprintf("%s Failed to delete cache entry: %s", PluginLoggerPrefix, result.Error)) + } + } + plugin.logger.Info(fmt.Sprintf("%s Cleanup completed - deleted all cache entries", PluginLoggerPrefix)) + + if err := plugin.store.DeleteNamespace(ctx, plugin.config.VectorStoreNamespace); err != nil { + return fmt.Errorf("failed to delete namespace: %w", err) + } + + return nil +} + +// Public Methods for External Use + +// ClearCacheForKey deletes cache entries for a specific cache key. +// Uses the unified VectorStore interface for deletion of all entries with the given cache key. +// +// Parameters: +// - cacheKey: The specific cache key to delete +// +// Returns: +// - error: Any error that occurred during cache key deletion +func (plugin *Plugin) ClearCacheForKey(cacheKey string) error { + // Delete all entries with "cache_key" equal to the given cacheKey + queries := []vectorstore.Query{ + { + Field: "cache_key", + Operator: vectorstore.QueryOperatorEqual, + Value: cacheKey, + }, + { + Field: "from_bifrost_semantic_cache_plugin", + Operator: vectorstore.QueryOperatorEqual, + Value: true, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) + defer cancel() + results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries) + if err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to delete cache entries for key '%s': %v", PluginLoggerPrefix, cacheKey, err)) + return err + } + + for _, result := range results { + if result.Status == vectorstore.DeleteStatusError { + plugin.logger.Warn(fmt.Sprintf("%s Failed to delete cache entry for key %s: %s", PluginLoggerPrefix, result.ID, result.Error)) + } + } + + plugin.logger.Debug(fmt.Sprintf("%s Deleted all cache entries for key %s", PluginLoggerPrefix, cacheKey)) + + return nil +} + +// ClearCacheForRequestID deletes cache entries for a specific request ID. +// Uses the unified VectorStore interface to delete the single entry by its UUID. +// +// Parameters: +// - requestID: The UUID-based request ID to delete cache entries for +// +// Returns: +// - error: Any error that occurred during cache key deletion +func (plugin *Plugin) ClearCacheForRequestID(requestID string) error { + // With the unified VectorStore interface, we delete the single entry by its UUID + ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) + defer cancel() + if err := plugin.store.Delete(ctx, plugin.config.VectorStoreNamespace, requestID); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to delete cache entry: %v", PluginLoggerPrefix, err)) + return err + } + + plugin.logger.Debug(fmt.Sprintf("%s Deleted cache entry for key %s", PluginLoggerPrefix, requestID)) + + return nil +} diff --git a/plugins/semanticcache/plugin_cache_type_test.go b/plugins/semanticcache/plugin_cache_type_test.go new file mode 100644 index 000000000..b5e62b6f7 --- /dev/null +++ b/plugins/semanticcache/plugin_cache_type_test.go @@ -0,0 +1,242 @@ +package semanticcache + +import ( + "context" + "testing" + "time" +) + +// TestCacheTypeDirectOnly tests that CacheTypeKey set to "direct" only performs direct hash matching +func TestCacheTypeDirectOnly(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // First, cache a response using normal behavior (both direct and semantic) + ctx1 := CreateContextWithCacheKey("test-cache-type-direct") + testRequest := CreateBasicChatRequest("What is Bifrost?", 0.7, 50) + + t.Log("Making first request to populate cache...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Now test with CacheTypeKey set to direct only + ctx2 := CreateContextWithCacheKeyAndType("test-cache-type-direct", CacheTypeDirect) + + t.Log("Making second request with CacheTypeKey=direct...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Should be a cache hit from direct search + AssertCacheHit(t, response2, "direct") + + t.Log("βœ… CacheTypeKey=direct correctly performs only direct hash matching") +} + +// TestCacheTypeSemanticOnly tests that CacheTypeKey set to "semantic" only performs semantic search +func TestCacheTypeSemanticOnly(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // First, cache a response using normal behavior + ctx1 := CreateContextWithCacheKey("test-cache-type-semantic") + testRequest := CreateBasicChatRequest("Explain machine learning concepts", 0.7, 50) + + t.Log("Making first request to populate cache...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test with slightly different wording that should match semantically but not directly + similarRequest := CreateBasicChatRequest("Can you explain concepts in machine learning", 0.7, 50) + + // Try with semantic-only search + ctx2 := CreateContextWithCacheKeyAndType("test-cache-type-semantic", CacheTypeSemantic) + + t.Log("Making second request with similar content and CacheTypeKey=semantic...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, similarRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // This might be a cache hit if semantic similarity is high enough + // The test validates that semantic search is attempted + if response2.ExtraFields.CacheDebug != nil && response2.ExtraFields.CacheDebug.CacheHit { + AssertCacheHit(t, response2, "semantic") + t.Log("βœ… CacheTypeKey=semantic correctly found semantic match") + } else { + t.Log("ℹ️ No semantic match found (threshold may be too high for these similar phrases)") + AssertNoCacheHit(t, response2) + } + + t.Log("βœ… CacheTypeKey=semantic correctly performs only semantic search") +} + +// TestCacheTypeDirectWithSemanticFallback tests the default behavior (both direct and semantic) +func TestCacheTypeDirectWithSemanticFallback(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Cache a response first + ctx1 := CreateContextWithCacheKey("test-cache-type-fallback") + testRequest := CreateBasicChatRequest("Define artificial intelligence", 0.7, 50) + + t.Log("Making first request to populate cache...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test exact match (should hit direct cache) + ctx2 := CreateContextWithCacheKey("test-cache-type-fallback") + + t.Log("Making second identical request (should hit direct cache)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + // Test similar request (should potentially hit semantic cache) + similarRequest := CreateBasicChatRequest("What is artificial intelligence", 0.7, 50) + + t.Log("Making third similar request (should attempt semantic match)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, similarRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + + // May or may not be a cache hit depending on semantic similarity + if response3.ExtraFields.CacheDebug != nil && response3.ExtraFields.CacheDebug.CacheHit { + AssertCacheHit(t, response3, "semantic") + t.Log("βœ… Default behavior correctly found semantic match") + } else { + t.Log("ℹ️ No semantic match found (normal for different wording)") + AssertNoCacheHit(t, response3) + } + + t.Log("βœ… Default behavior correctly attempts both direct and semantic search") +} + +// TestCacheTypeInvalidValue tests behavior with invalid CacheTypeKey values +func TestCacheTypeInvalidValue(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Create context with invalid cache type + ctx := CreateContextWithCacheKey("test-invalid-cache-type") + ctx = context.WithValue(ctx, CacheTypeKey, "invalid_type") + + testRequest := CreateBasicChatRequest("Test invalid cache type", 0.7, 50) + + t.Log("Making request with invalid CacheTypeKey value...") + response, err := setup.Client.ChatCompletionRequest(ctx, testRequest) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + + // Should fall back to default behavior (both direct and semantic) + AssertNoCacheHit(t, response) + + t.Log("βœ… Invalid CacheTypeKey value falls back to default behavior") +} + +// TestCacheTypeWithEmbeddingRequests tests CacheTypeKey behavior with embedding requests +func TestCacheTypeWithEmbeddingRequests(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + embeddingRequest := CreateEmbeddingRequest([]string{"Test embedding with cache type"}) + + // Cache first request + ctx1 := CreateContextWithCacheKey("test-embedding-cache-type") + t.Log("Making first embedding request...") + response1, err1 := setup.Client.EmbeddingRequest(ctx1, embeddingRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test with direct-only cache type + ctx2 := CreateContextWithCacheKeyAndType("test-embedding-cache-type", CacheTypeDirect) + t.Log("Making second embedding request with CacheTypeKey=direct...") + response2, err2 := setup.Client.EmbeddingRequest(ctx2, embeddingRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + // Test with semantic-only cache type (should not find semantic match for embeddings) + ctx3 := CreateContextWithCacheKeyAndType("test-embedding-cache-type", CacheTypeSemantic) + t.Log("Making third embedding request with CacheTypeKey=semantic...") + response3, err3 := setup.Client.EmbeddingRequest(ctx3, embeddingRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + // Semantic search should be skipped for embedding requests + AssertNoCacheHit(t, response3) + + t.Log("βœ… CacheTypeKey works correctly with embedding requests") +} + +// TestCacheTypePerformanceCharacteristics tests that different cache types have expected performance +func TestCacheTypePerformanceCharacteristics(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Performance test for cache types", 0.7, 50) + + // Cache first request + ctx1 := CreateContextWithCacheKey("test-cache-performance") + t.Log("Making first request to populate cache...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test direct-only performance + ctx2 := CreateContextWithCacheKeyAndType("test-cache-performance", CacheTypeDirect) + start2 := time.Now() + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + duration2 := time.Since(start2) + if err2 != nil { + t.Fatalf("Direct cache request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + t.Logf("Direct cache lookup took: %v", duration2) + + // Test default behavior (both direct and semantic) performance + ctx3 := CreateContextWithCacheKey("test-cache-performance") + start3 := time.Now() + response3, err3 := setup.Client.ChatCompletionRequest(ctx3, testRequest) + duration3 := time.Since(start3) + if err3 != nil { + t.Fatalf("Default cache request failed: %v", err3) + } + AssertCacheHit(t, response3, "direct") + + t.Logf("Default cache lookup took: %v", duration3) + + // Both should be fast since they hit direct cache + // Direct-only might be slightly faster as it doesn't need to prepare for semantic fallback + t.Log("βœ… Cache type performance characteristics validated") +} diff --git a/plugins/semanticcache/plugin_conversation_config_test.go b/plugins/semanticcache/plugin_conversation_config_test.go new file mode 100644 index 000000000..28bffe3d8 --- /dev/null +++ b/plugins/semanticcache/plugin_conversation_config_test.go @@ -0,0 +1,430 @@ +package semanticcache + +import ( + "strconv" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestConversationHistoryThresholdBasic tests basic conversation history threshold functionality +func TestConversationHistoryThresholdBasic(t *testing.T) { + // Test with threshold of 2 messages + setup := CreateTestSetupWithConversationThreshold(t, 2) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-conversation-threshold-basic") + + // Test 1: Conversation with exactly 2 messages (should cache) + conversation1 := BuildConversationHistory("", + []string{"Hello", "Hi there!"}, + ) + request1 := CreateConversationRequest(conversation1, 0.7, 50) + + t.Log("Testing conversation with exactly 2 messages (at threshold)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Fresh request + + WaitForCache() + + // Verify it was cached + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request1) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should be cached + + // Test 2: Conversation with 3 messages (exceeds threshold, should NOT cache) + conversation2 := BuildConversationHistory("", + []string{"Hello", "Hi there!"}, + []string{"How are you?", "I'm doing well!"}, + ) + messages2 := AddUserMessage(conversation2, "What's the weather?") + request2 := CreateConversationRequest(messages2, 0.7, 50) // 5 messages total > 2 + + t.Log("Testing conversation with 5 messages (exceeds threshold)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx, request2) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) // Should not cache + + WaitForCache() + + // Verify it was NOT cached + t.Log("Verifying conversation exceeding threshold was not cached...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx, request2) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertNoCacheHit(t, response4) // Should still be fresh (not cached) + + t.Log("βœ… Conversation history threshold works correctly") +} + +// TestConversationHistoryThresholdWithSystemPrompt tests threshold with system messages +func TestConversationHistoryThresholdWithSystemPrompt(t *testing.T) { + // Test with threshold of 3, ExcludeSystemPrompt = false + setup := CreateTestSetupWithConversationThreshold(t, 3) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-threshold-system-prompt") + + // System prompt + 2 user/assistant pairs = 5 messages total > 3 + conversation := BuildConversationHistory( + "You are a helpful assistant", // System message (counts toward threshold) + []string{"Hello", "Hi there!"}, + []string{"How are you?", "I'm doing well!"}, + ) + request := CreateConversationRequest(conversation, 0.7, 50) + + t.Log("Testing conversation with system prompt (5 total messages > 3 threshold)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Should not cache (exceeds threshold) + + WaitForCache() + + // Verify not cached + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should not be cached + + t.Log("βœ… Conversation threshold correctly counts system messages") +} + +// TestConversationHistoryThresholdWithExcludeSystemPrompt tests interaction between threshold and exclude system prompt +func TestConversationHistoryThresholdWithExcludeSystemPrompt(t *testing.T) { + // Create setup with both threshold=3 and ExcludeSystemPrompt=true + setup := CreateTestSetupWithThresholdAndExcludeSystem(t, 3, true) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-threshold-exclude-system") + + // Create conversation with exactly 3 non-system messages to test threshold boundary + // System + 1.5 user/assistant pairs = 4 messages total + // With ExcludeSystemPrompt=true, should only count 3 non-system messages for threshold + conversation := BuildConversationHistory( + "You are helpful", // System (excluded from count) + []string{"Hello", "Hi"}, // User + Assistant = 2 messages + []string{"Thanks", ""}, // User only = 1 message (no assistant response) + ) + // No slicing needed; BuildConversationHistory skips empty assistant entries. + request := CreateConversationRequest(conversation, 0.7, 50) // 3 non-system messages exactly + + t.Log("Testing threshold with ExcludeSystemPrompt=true (3 non-system messages = at threshold)...") + + // Test logic: + // - Total messages: 4 (1 system + 3 others) + // - With ExcludeSystemPrompt=true: counts as 3 non-system messages + // - Threshold is 3, so 3 <= 3 should allow caching + + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Fresh request, should not hit cache + + WaitForCache() + + // Second request should hit cache (3 non-system messages <= 3 threshold) + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should cache since 3 <= 3 after excluding system + + t.Log("βœ… Conversation threshold respects ExcludeSystemPrompt setting") +} + +// TestConversationHistoryThresholdDifferentValues tests different threshold values +func TestConversationHistoryThresholdDifferentValues(t *testing.T) { + testCases := []struct { + name string + threshold int + messages int + shouldCache bool + }{ + {"Threshold 1, 1 message", 1, 1, true}, + {"Threshold 1, 2 messages", 1, 2, false}, + {"Threshold 5, 4 messages", 5, 4, true}, + {"Threshold 5, 6 messages", 5, 6, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setup := CreateTestSetupWithConversationThreshold(t, tc.threshold) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-threshold-" + tc.name) + + // Build conversation with specified number of messages + var conversation []schemas.BifrostMessage + for i := 0; i < tc.messages; i++ { + role := schemas.ModelChatMessageRoleUser + if i%2 == 1 { + role = schemas.ModelChatMessageRoleAssistant + } + message := schemas.BifrostMessage{ + Role: role, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Message " + strconv.Itoa(i+1)), + }, + } + conversation = append(conversation, message) + } + + request := CreateConversationRequest(conversation, 0.7, 50) + + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request) + if err1 != nil { + t.Fatalf("Request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Always fresh first time + + WaitForCache() + + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + if tc.shouldCache { + AssertCacheHit(t, response2, "direct") + } else { + AssertNoCacheHit(t, response2) + } + }) + } + + t.Log("βœ… Different conversation threshold values work correctly") +} + +// TestExcludeSystemPromptBasic tests basic ExcludeSystemPrompt functionality +func TestExcludeSystemPromptBasic(t *testing.T) { + // Test with ExcludeSystemPrompt = true + setup := CreateTestSetupWithExcludeSystemPrompt(t, true) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-exclude-system-basic") + + // Create two conversations with different system prompts but same user/assistant messages + conversation1 := BuildConversationHistory( + "You are a helpful assistant", + []string{"What is AI?", "AI is artificial intelligence."}, + ) + + conversation2 := BuildConversationHistory( + "You are a technical expert", // Different system prompt + []string{"What is AI?", "AI is artificial intelligence."}, // Same user/assistant + ) + + request1 := CreateConversationRequest(conversation1, 0.7, 50) + request2 := CreateConversationRequest(conversation2, 0.7, 50) + + t.Log("Caching conversation with system prompt 1...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + t.Log("Testing conversation with different system prompt (should hit cache due to ExcludeSystemPrompt=true)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + // Should hit cache because system prompts are excluded from cache key + AssertCacheHit(t, response2, "direct") + + t.Log("βœ… ExcludeSystemPrompt=true correctly ignores system prompts in cache keys") +} + +// TestExcludeSystemPromptComparison tests ExcludeSystemPrompt true vs false +func TestExcludeSystemPromptComparison(t *testing.T) { + // Test 1: ExcludeSystemPrompt = false (default) + setup1 := CreateTestSetupWithExcludeSystemPrompt(t, false) + defer setup1.Cleanup() + + ctx1 := CreateContextWithCacheKey("test-exclude-system-false") + + conversation1 := BuildConversationHistory( + "You are helpful", + []string{"Hello", "Hi there!"}, + ) + + conversation2 := BuildConversationHistory( + "You are an expert", // Different system prompt + []string{"Hello", "Hi there!"}, // Same user/assistant + ) + + request1 := CreateConversationRequest(conversation1, 0.7, 50) + request2 := CreateConversationRequest(conversation2, 0.7, 50) + + t.Log("Testing ExcludeSystemPrompt=false...") + response1, err1 := setup1.Client.ChatCompletionRequest(ctx1, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + response2, err2 := setup1.Client.ChatCompletionRequest(ctx1, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + // Should NOT hit direct cache, but might hit semantic cache due to similar content + if response2.ExtraFields.CacheDebug != nil && response2.ExtraFields.CacheDebug.CacheHit { + if response2.ExtraFields.CacheDebug.HitType != nil && *response2.ExtraFields.CacheDebug.HitType == "semantic" { + t.Log("βœ… Found semantic cache match (expected with similar content)") + } else { + t.Error("❌ Unexpected direct cache hit with different system prompts") + } + } else { + t.Log("βœ… No cache hit (system prompts create different cache keys)") + } + + // Test 2: ExcludeSystemPrompt = true + setup2 := CreateTestSetupWithExcludeSystemPrompt(t, true) + defer setup2.Cleanup() + + ctx2 := CreateContextWithCacheKey("test-exclude-system-true") + + t.Log("Testing ExcludeSystemPrompt=true...") + response3, err3 := setup2.Client.ChatCompletionRequest(ctx2, request1) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) + + WaitForCache() + + response4, err4 := setup2.Client.ChatCompletionRequest(ctx2, request2) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + // Should hit cache because system prompts are excluded from cache key + AssertCacheHit(t, response4, "direct") + + t.Log("βœ… ExcludeSystemPrompt true vs false comparison works correctly") +} + +// TestExcludeSystemPromptWithMultipleSystemMessages tests behavior with multiple system messages +func TestExcludeSystemPromptWithMultipleSystemMessages(t *testing.T) { + setup := CreateTestSetupWithExcludeSystemPrompt(t, true) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-multiple-system-messages") + + // Manually create conversation with multiple system messages + conversation1 := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("You are helpful")}, + }, + { + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Be concise")}, + }, + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + }, + { + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi!")}, + }, + } + + conversation2 := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("You are an expert")}, + }, + { + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Be detailed")}, + }, + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + }, + { + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi!")}, + }, + } + + request1 := CreateConversationRequest(conversation1, 0.7, 50) + request2 := CreateConversationRequest(conversation2, 0.7, 50) + + t.Log("Caching conversation with multiple system messages...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + t.Log("Testing conversation with different multiple system messages...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + // Should hit cache because all system messages are excluded + AssertCacheHit(t, response2, "direct") + + t.Log("βœ… ExcludeSystemPrompt works with multiple system messages") +} + +// TestExcludeSystemPromptWithNoSystemMessages tests behavior when there are no system messages +func TestExcludeSystemPromptWithNoSystemMessages(t *testing.T) { + setup := CreateTestSetupWithExcludeSystemPrompt(t, true) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-no-system-messages") + + // Conversation with no system messages + conversation := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + }, + { + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi there!")}, + }, + } + + request := CreateConversationRequest(conversation, 0.7, 50) + + t.Log("Testing conversation with no system messages...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Should cache normally + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + t.Log("βœ… ExcludeSystemPrompt works correctly when no system messages present") +} diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go new file mode 100644 index 000000000..d1754731b --- /dev/null +++ b/plugins/semanticcache/plugin_core_test.go @@ -0,0 +1,419 @@ +package semanticcache + +import ( + "context" + "os" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/vectorstore" +) + +// TestSemanticCacheBasicFunctionality tests the core caching functionality +func TestSemanticCacheBasicFunctionality(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-basic-value") + + // Create test request + testRequest := CreateBasicChatRequest( + "What is Bifrost? Answer in one short sentence.", + 0.7, + 50, + ) + + t.Log("Making first request (should go to OpenAI and be cached)...") + + // Make first request (will go to OpenAI and be cached) + start1 := time.Now() + response1, err1 := setup.Client.ChatCompletionRequest(ctx, testRequest) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + if response1 == nil || len(response1.Choices) == 0 || response1.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("First response is invalid") + } + + t.Logf("First request completed in %v", duration1) + t.Logf("Response: %s", *response1.Choices[0].Message.Content.ContentStr) + + // Wait for cache to be written + WaitForCache() + + t.Log("Making second identical request (should be served from cache)...") + + // Make second identical request (should be cached) + start2 := time.Now() + response2, err2 := setup.Client.ChatCompletionRequest(ctx, testRequest) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + if response2 == nil || len(response2.Choices) == 0 || response2.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Second response is invalid") + } + + t.Logf("Second request completed in %v", duration2) + t.Logf("Response: %s", *response2.Choices[0].Message.Content.ContentStr) + + // Verify cache hit + AssertCacheHit(t, response2, string(CacheTypeDirect)) + + // Performance comparison + t.Logf("Performance Summary:") + t.Logf("First request (OpenAI): %v", duration1) + t.Logf("Second request (Cache): %v", duration2) + + if duration2 >= duration1 { + t.Errorf("Cache request took longer than original request: cache=%v, original=%v", duration2, duration1) + } else { + speedup := float64(duration1) / float64(duration2) + t.Logf("Cache speedup: %.2fx faster", speedup) + + // Assert that cache is at least 1.5x faster (reasonable expectation) + if speedup < 1.5 { + t.Errorf("Cache speedup is less than 1.5x: got %.2fx", speedup) + } + } + + // Verify responses are identical (content should be the same) + content1 := *response1.Choices[0].Message.Content.ContentStr + content2 := *response2.Choices[0].Message.Content.ContentStr + + if content1 != content2 { + t.Errorf("Response content differs between cached and original:\nOriginal: %s\nCached: %s", content1, content2) + } + + // Verify provider information is maintained in cached response + if response2.ExtraFields.Provider != testRequest.Provider { + t.Errorf("Provider mismatch in cached response: expected %s, got %s", + testRequest.Provider, response2.ExtraFields.Provider) + } + + t.Log("βœ… Basic semantic caching test completed successfully!") +} + +// TestSemanticSearch tests the semantic similarity search functionality +func TestSemanticSearch(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Lower threshold for more flexible matching + setup.Config.Threshold = 0.5 + + ctx := CreateContextWithCacheKey("semantic-test-value") + + // First request - this will be cached + firstRequest := CreateBasicChatRequest( + "What is machine learning? Explain briefly.", + 0.0, // Use 0 temperature for consistent results + 50, + ) + + t.Log("Making first request (should go to OpenAI and be cached)...") + start1 := time.Now() + response1, err1 := setup.Client.ChatCompletionRequest(ctx, firstRequest) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + if response1 == nil || len(response1.Choices) == 0 || response1.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("First response is invalid") + } + + t.Logf("First request completed in %v", duration1) + t.Logf("Response: %s", *response1.Choices[0].Message.Content.ContentStr) + + // Wait for cache to be written (async PostHook needs time to complete) + WaitForCache() + + // Second request - very similar text to test semantic matching + secondRequest := CreateBasicChatRequest( + "What is machine learning? Explain it briefly.", + 0.0, // Use 0 temperature for consistent results + 50, + ) + + t.Log("Making semantically similar request (should be served from semantic cache)...") + start2 := time.Now() + response2, err2 := setup.Client.ChatCompletionRequest(ctx, secondRequest) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + if response2 == nil || len(response2.Choices) == 0 || response2.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Second response is invalid") + } + + t.Logf("Second request completed in %v", duration2) + t.Logf("Response: %s", *response2.Choices[0].Message.Content.ContentStr) + + // Check if second request was served from semantic cache + semanticMatch := false + + if response2.ExtraFields.CacheDebug != nil && response2.ExtraFields.CacheDebug.CacheHit { + if response2.ExtraFields.CacheDebug.HitType != nil && *response2.ExtraFields.CacheDebug.HitType == string(CacheTypeSemantic) { + semanticMatch = true + + threshold := 0.0 + similarity := 0.0 + + if response2.ExtraFields.CacheDebug.Threshold != nil { + threshold = *response2.ExtraFields.CacheDebug.Threshold + } + if response2.ExtraFields.CacheDebug.Similarity != nil { + similarity = *response2.ExtraFields.CacheDebug.Similarity + } + + t.Logf("βœ… Second request was served from semantic cache! Cache threshold: %f, Cache similarity: %f", threshold, similarity) + } + } + + if !semanticMatch { + t.Error("Semantic match expected but not found") + return + } + + // Performance comparison + t.Logf("Semantic Cache Performance:") + t.Logf("First request (OpenAI): %v", duration1) + t.Logf("Second request (Semantic): %v", duration2) + + if duration2 < duration1 { + speedup := float64(duration1) / float64(duration2) + t.Logf("Semantic cache speedup: %.2fx faster", speedup) + } + + t.Log("βœ… Semantic search test completed successfully!") +} + +// TestDirectVsSemanticSearch tests the difference between direct hash matching and semantic search +func TestDirectVsSemanticSearch(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Lower threshold for more flexible semantic matching + setup.Config.Threshold = 0.2 + + ctx := CreateContextWithCacheKey("direct-vs-semantic-test") + + // Test Case 1: Exact same request (should use direct hash matching) + t.Log("=== Test Case 1: Exact Same Request (Direct Hash Match) ===") + + exactRequest := CreateBasicChatRequest( + "What is artificial intelligence?", + 0.1, + 100, + ) + + t.Log("Making first request...") + _, err1 := setup.Client.ChatCompletionRequest(ctx, exactRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + WaitForCache() + + t.Log("Making exact same request (should hit direct cache)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx, exactRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Should be a direct cache hit + AssertCacheHit(t, response2, string(CacheTypeDirect)) + + // Test Case 2: Similar but different request (should use semantic search) + t.Log("\n=== Test Case 2: Semantically Similar Request ===") + + semanticRequest := CreateBasicChatRequest( + "Can you explain what AI is?", // Similar but different wording + 0.1, // Same parameters + 100, + ) + + t.Log("Making semantically similar request...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx, semanticRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + + semanticMatch := false + + // Check if it was served from cache and what type + if response3.ExtraFields.CacheDebug != nil && response3.ExtraFields.CacheDebug.CacheHit { + if response3.ExtraFields.CacheDebug.HitType != nil && *response3.ExtraFields.CacheDebug.HitType == string(CacheTypeSemantic) { + semanticMatch = true + + threshold := 0.0 + similarity := 0.0 + + if response3.ExtraFields.CacheDebug.Threshold != nil { + threshold = *response3.ExtraFields.CacheDebug.Threshold + } + if response3.ExtraFields.CacheDebug.Similarity != nil { + similarity = *response3.ExtraFields.CacheDebug.Similarity + } + + t.Logf("βœ… Third request was served from semantic cache! Cache threshold: %f, Cache similarity: %f", threshold, similarity) + } + } + + if !semanticMatch { + t.Error("Semantic match expected but not found") + return + } + + t.Log("βœ… Direct vs semantic search test completed!") +} + +// TestNoCacheScenarios tests scenarios where caching should NOT occur +func TestNoCacheScenarios(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("no-cache-test") + + // Test Case 1: Different parameters should NOT cache hit + t.Log("=== Test Case 1: Different Parameters ===") + + basePrompt := "What is the capital of France?" + + // First request + request1 := CreateBasicChatRequest(basePrompt, 0.1, 50) + _, err1 := setup.Client.ChatCompletionRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + WaitForCache() + + // Second request with different temperature + request2 := CreateBasicChatRequest(basePrompt, 0.9, 50) // Different temperature + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Should NOT be cached + AssertNoCacheHit(t, response2) + + // Test Case 2: Different max_tokens should NOT cache hit + t.Log("\n=== Test Case 2: Different MaxTokens ===") + + request3 := CreateBasicChatRequest(basePrompt, 0.1, 200) // Different max_tokens + response3, err3 := setup.Client.ChatCompletionRequest(ctx, request3) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + + // Should NOT be cached + AssertNoCacheHit(t, response3) + + t.Log("βœ… No cache scenarios test completed!") +} + +// TestCacheConfiguration tests different cache configuration options +func TestCacheConfiguration(t *testing.T) { + tests := []struct { + name string + config Config + expectedBehavior string + }{ + { + name: "High Threshold", + config: Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.95, // Very high threshold + Keys: []schemas.Key{ + {Value: os.Getenv("OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + }, + }, + expectedBehavior: "strict_matching", + }, + { + name: "Low Threshold", + config: Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.1, // Very low threshold + Keys: []schemas.Key{ + {Value: os.Getenv("OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + }, + }, + expectedBehavior: "loose_matching", + }, + { + name: "Custom TTL", + config: Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.8, + TTL: 1 * time.Hour, // Custom TTL + Keys: []schemas.Key{ + {Value: os.Getenv("OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + }, + }, + expectedBehavior: "custom_ttl", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup := NewTestSetupWithConfig(t, tt.config) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("config-test-" + tt.name) + + // Basic functionality test with the configuration + testRequest := CreateBasicChatRequest("Test configuration: "+tt.name, 0.5, 50) + + _, err1 := setup.Client.ChatCompletionRequest(ctx, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + WaitForCache() + + _, err2 := setup.Client.ChatCompletionRequest(ctx, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + t.Logf("βœ… Configuration test '%s' completed", tt.name) + }) + } +} + +// MockUnsupportedStore is a mock store that returns ErrNotSupported for semantic operations +type MockUnsupportedStore struct { + vectorstore.VectorStore // Embed interface to implement all methods +} + +func (m *MockUnsupportedStore) SearchSemanticCache(ctx context.Context, queryEmbedding []float32, metadata map[string]interface{}, threshold float64, limit int64) ([]vectorstore.SearchResult, error) { + return nil, vectorstore.ErrNotSupported +} + +func (m *MockUnsupportedStore) AddSemanticCache(ctx context.Context, key string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) error { + return vectorstore.ErrNotSupported +} + +func (m *MockUnsupportedStore) EnsureSemanticIndex(ctx context.Context, keyPrefix string, embeddingDim int, metadataFields []string) error { + return vectorstore.ErrNotSupported +} + +func (m *MockUnsupportedStore) Close(ctx context.Context) error { + return nil +} diff --git a/plugins/semanticcache/plugin_cross_cache_test.go b/plugins/semanticcache/plugin_cross_cache_test.go new file mode 100644 index 000000000..6c6bff0b4 --- /dev/null +++ b/plugins/semanticcache/plugin_cross_cache_test.go @@ -0,0 +1,310 @@ +package semanticcache + +import ( + "context" + "testing" +) + +// TestCrossCacheTypeAccessibility tests that entries cached one way are accessible another way +func TestCrossCacheTypeAccessibility(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("What is artificial intelligence?", 0.7, 100) + + // Test 1: Cache with default behavior (both direct + semantic) + ctx1 := CreateContextWithCacheKey("test-cross-cache-access") + t.Log("Caching with default behavior (both direct + semantic)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test 2: Retrieve with direct-only cache type + ctx2 := CreateContextWithCacheKeyAndType("test-cross-cache-access", CacheTypeDirect) + t.Log("Retrieving with CacheTypeKey=direct...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should find direct match + + // Test 3: Retrieve with semantic-only cache type + ctx3 := CreateContextWithCacheKeyAndType("test-cross-cache-access", CacheTypeSemantic) + t.Log("Retrieving with CacheTypeKey=semantic...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx3, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertCacheHit(t, response3, "semantic") // Should find semantic match + + t.Log("βœ… Entries cached with default behavior are accessible via both cache types") +} + +// TestCacheTypeIsolation tests that entries cached separately by type behave correctly +func TestCacheTypeIsolation(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Define blockchain technology", 0.7, 100) + + // Clear cache to start fresh + clearTestKeysWithStore(t, setup.Store) + + // Test 1: Cache with direct-only + ctx1 := CreateContextWithCacheKeyAndType("test-cache-isolation", CacheTypeDirect) + t.Log("Caching with CacheTypeKey=direct only...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Fresh request + + WaitForCache() + + // Test 2: Try to retrieve with semantic-only (should miss because no semantic entry) + ctx2 := CreateContextWithCacheKeyAndType("test-cache-isolation", CacheTypeSemantic) + t.Log("Retrieving same request with CacheTypeKey=semantic (should miss)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should miss - no semantic cache entry + + WaitForCache() + + // Test 3: Retrieve with direct-only (should hit) + t.Log("Retrieving with CacheTypeKey=direct (should hit)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertCacheHit(t, response3, "direct") // Should hit direct cache + + // Test 4: Default behavior (should find the direct cache) + ctx4 := CreateContextWithCacheKey("test-cache-isolation") + t.Log("Retrieving with default behavior (should find direct cache)...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx4, testRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertCacheHit(t, response4, "direct") // Should find existing direct cache + + t.Log("βœ… Cache type isolation works correctly") +} + +// TestCacheTypeFallbackBehavior tests whether cache types fallback to each other +func TestCacheTypeFallbackBehavior(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Cache an entry with default behavior + originalRequest := CreateBasicChatRequest("Explain machine learning", 0.7, 100) + ctx1 := CreateContextWithCacheKey("test-fallback-behavior") + + t.Log("Caching with default behavior...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, originalRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test similar request with direct-only (should miss direct, no fallback, but should cache response) + similarRequest := CreateBasicChatRequest("Explain machine learning concepts", 0.7, 100) + ctx2 := CreateContextWithCacheKeyAndType("test-fallback-behavior", CacheTypeDirect) + + t.Log("Testing similar request with CacheTypeKey=direct (should miss, make request, cache without embeddings)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, similarRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should miss - no direct match, no semantic search + + WaitForCache() // Let the response get cached + + // Test same similar request with semantic-only (should hit original entry) + ctx3 := CreateContextWithCacheKeyAndType("test-fallback-behavior", CacheTypeSemantic) + + t.Log("Testing similar request with CacheTypeKey=semantic (should find semantic match from step 1)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx3, similarRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + + // Should find semantic match from step 1's cached entry (which has embeddings) + if response3.ExtraFields.CacheDebug != nil && response3.ExtraFields.CacheDebug.CacheHit { + AssertCacheHit(t, response3, "semantic") + t.Log("βœ… Semantic search found similar entry from step 1") + } else { + AssertNoCacheHit(t, response3) + t.Log("ℹ️ No semantic match found (threshold may be too high or semantic similarity low)") + } + + // Test a different similar request with default behavior (try both, fallback to semantic) + // Use a slightly different request to avoid hitting the cached response from step 2 + differentSimilarRequest := CreateBasicChatRequest("Explain the basics of machine learning", 0.7, 100) + ctx4 := CreateContextWithCacheKey("test-fallback-behavior") + + t.Log("Testing different similar request with default behavior (direct miss -> semantic fallback)...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx4, differentSimilarRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + + // Should try direct first (miss), then semantic (might hit) + if response4.ExtraFields.CacheDebug != nil && response4.ExtraFields.CacheDebug.CacheHit { + AssertCacheHit(t, response4, "semantic") + t.Log("βœ… Default behavior found semantic fallback") + } else { + AssertNoCacheHit(t, response4) + t.Log("ℹ️ No fallback match found") + } + + t.Log("βœ… Cache type fallback behavior verified") +} + +// TestMultipleCacheEntriesPriority tests behavior when multiple cache entries exist +func TestMultipleCacheEntriesPriority(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("What is deep learning?", 0.7, 100) + + // Create cache entry with default behavior first + ctx1 := CreateContextWithCacheKey("test-cache-priority") + t.Log("Creating cache entry with default behavior...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + originalContent := *response1.Choices[0].Message.Content.ContentStr + + WaitForCache() + + // Verify it hits cache with default behavior + t.Log("Verifying cache hit with default behavior...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should hit direct cache + cachedContent := *response2.Choices[0].Message.Content.ContentStr + + // Verify content is the same + if originalContent != cachedContent { + t.Errorf("Cache content mismatch:\nOriginal: %s\nCached: %s", originalContent, cachedContent) + } + + // Test with direct-only access + ctx2 := CreateContextWithCacheKeyAndType("test-cache-priority", CacheTypeDirect) + t.Log("Accessing with CacheTypeKey=direct...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertCacheHit(t, response3, "direct") // Should find direct cache + + // Test with semantic-only access + ctx3 := CreateContextWithCacheKeyAndType("test-cache-priority", CacheTypeSemantic) + t.Log("Accessing with CacheTypeKey=semantic...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx3, testRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertCacheHit(t, response4, "semantic") // Should find semantic cache + + t.Log("βœ… Multiple cache entries accessible correctly") +} + +// TestCrossCacheTypeWithDifferentParameters tests cache type behavior with parameter variations +func TestCrossCacheTypeWithDifferentParameters(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + baseMessage := "Explain quantum computing" + + // Cache with specific parameters + request1 := CreateBasicChatRequest(baseMessage, 0.7, 100) + ctx1 := CreateContextWithCacheKey("test-cross-cache-params") + + t.Log("Caching with temp=0.7, max_tokens=100...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Test same parameters with direct-only + ctx2 := CreateContextWithCacheKeyAndType("test-cross-cache-params", CacheTypeDirect) + t.Log("Retrieving same parameters with CacheTypeKey=direct...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, request1) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should hit + + // Test different parameters - should miss + request3 := CreateBasicChatRequest(baseMessage, 0.5, 200) // Different temp and tokens + t.Log("Testing different parameters (should miss)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, request3) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) // Should miss due to different params + + // Test semantic search with different parameters + ctx4 := CreateContextWithCacheKeyAndType("test-cross-cache-params", CacheTypeSemantic) + similarRequest := CreateBasicChatRequest("Can you explain quantum computing", 0.5, 200) + + t.Log("Testing semantic search with different params and similar message...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx4, similarRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + // Should miss semantic search due to different parameters (params_hash different) + AssertNoCacheHit(t, response4) + + t.Log("βœ… Cross-cache-type parameter handling works correctly") +} + +// TestCacheTypeErrorHandling tests error scenarios with cache types +func TestCacheTypeErrorHandling(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Test error handling", 0.7, 50) + + // Test invalid cache type (should fallback to default) + ctx1 := CreateContextWithCacheKey("test-cache-error-handling") + ctx1 = context.WithValue(ctx1, CacheTypeKey, "invalid_cache_type") + + t.Log("Testing invalid cache type (should fallback to default behavior)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Should work with fallback behavior + + WaitForCache() + + // Test nil cache type (should use default) + ctx2 := CreateContextWithCacheKey("test-cache-error-handling") + ctx2 = context.WithValue(ctx2, CacheTypeKey, nil) + + t.Log("Testing nil cache type (should use default behavior)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should find cached entry from first request + + t.Log("βœ… Cache type error handling works correctly") +} diff --git a/plugins/semanticcache/plugin_edge_cases_test.go b/plugins/semanticcache/plugin_edge_cases_test.go new file mode 100644 index 000000000..7136e6ec3 --- /dev/null +++ b/plugins/semanticcache/plugin_edge_cases_test.go @@ -0,0 +1,648 @@ +package semanticcache + +import ( + "context" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestParameterVariations tests that different parameters don't cache hit inappropriately +func TestParameterVariations(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("param-variations-test") + basePrompt := "What is the capital of France?" + + tests := []struct { + name string + request1 *schemas.BifrostRequest + request2 *schemas.BifrostRequest + shouldCache bool + }{ + { + name: "Same Parameters", + request1: CreateBasicChatRequest(basePrompt, 0.5, 50), + request2: CreateBasicChatRequest(basePrompt, 0.5, 50), + shouldCache: true, + }, + { + name: "Different Temperature", + request1: CreateBasicChatRequest(basePrompt, 0.1, 50), + request2: CreateBasicChatRequest(basePrompt, 0.9, 50), + shouldCache: false, + }, + { + name: "Different MaxTokens", + request1: CreateBasicChatRequest(basePrompt, 0.5, 50), + request2: CreateBasicChatRequest(basePrompt, 0.5, 200), + shouldCache: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear cache for this subtest + clearTestKeysWithStore(t, setup.Store) + + // Make first request + _, err1 := setup.Client.ChatCompletionRequest(ctx, tt.request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + WaitForCache() + + // Make second request + response2, err2 := setup.Client.ChatCompletionRequest(ctx, tt.request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Check cache behavior + if tt.shouldCache { + AssertCacheHit(t, response2, string(CacheTypeDirect)) + } else { + AssertNoCacheHit(t, response2) + } + }) + } +} + +// TestToolVariations tests caching behavior with different tool configurations +func TestToolVariations(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("tool-variations-test") + + // Base request without tools + baseRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.5), + MaxTokens: bifrost.Ptr(100), + }, + } + + // Request with tools + requestWithTools := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.5), + MaxTokens: bifrost.Ptr(100), + Tools: &[]schemas.Tool{ + { + Type: "function", + Function: schemas.Function{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state", + }, + }, + }, + }, + }, + }, + }, + } + + // Request with different tools + requestWithDifferentTools := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.5), + MaxTokens: bifrost.Ptr(100), + Tools: &[]schemas.Tool{ + { + Type: "function", + Function: schemas.Function{ + Name: "get_current_weather", // Different name + Description: "Get current weather information", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "city": map[string]interface{}{ // Different parameter name + "type": "string", + "description": "The city name", + }, + }, + }, + }, + }, + }, + }, + } + + // Test 1: Request without tools + t.Log("Making request without tools...") + _, err1 := setup.Client.ChatCompletionRequest(ctx, baseRequest) + if err1 != nil { + t.Fatalf("Request without tools failed: %v", err1) + } + + WaitForCache() + + // Test 2: Request with tools (should NOT cache hit) + t.Log("Making request with tools...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx, requestWithTools) + if err2 != nil { + t.Fatalf("Request with tools failed: %v", err2) + } + + AssertNoCacheHit(t, response2) + + WaitForCache() + + // Test 3: Same request with tools (should cache hit) + t.Log("Making same request with tools again...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx, requestWithTools) + if err3 != nil { + t.Fatalf("Second request with tools failed: %v", err3) + } + + AssertCacheHit(t, response3, string(CacheTypeDirect)) + + // Test 4: Request with different tools (should NOT cache hit) + t.Log("Making request with different tools...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx, requestWithDifferentTools) + if err4 != nil { + t.Fatalf("Request with different tools failed: %v", err4) + } + + AssertNoCacheHit(t, response4) + + t.Log("βœ… Tool variations test completed!") +} + +// TestContentVariations tests caching behavior with different content types +func TestContentVariations(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("content-variations-test") + + tests := []struct { + name string + request *schemas.BifrostRequest + }{ + { + name: "Unicode Content", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("🌟 Unicode test: Hello, δΈ–η•Œ! Ω…Ψ±Ψ­Ψ¨Ψ§ 🌍"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.1), + MaxTokens: bifrost.Ptr(50), + }, + }, + }, + { + name: "Image URL Content", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: "text", + Text: bifrost.Ptr("Analyze this image"), + }, + { + Type: "image_url", + ImageURL: &schemas.ImageURLStruct{ + URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + }, + }, + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.3), + MaxTokens: bifrost.Ptr(200), + }, + }, + }, + { + name: "Multiple Images", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: "text", + Text: bifrost.Ptr("Compare these images"), + }, + { + Type: "image_url", + ImageURL: &schemas.ImageURLStruct{ + URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + }, + { + Type: "image_url", + ImageURL: &schemas.ImageURLStruct{ + URL: "https://upload.wikimedia.org/wikipedia/commons/b/b5/Scenery_.jpg", + }, + }, + }, + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.3), + MaxTokens: bifrost.Ptr(200), + }, + }, + }, + { + name: "Very Long Content", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(strings.Repeat("This is a very long prompt. ", 100)), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.2), + MaxTokens: bifrost.Ptr(50), + }, + }, + }, + { + name: "Multi-turn Conversation", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What is AI?"), + }, + }, + { + Role: "assistant", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("AI stands for Artificial Intelligence..."), + }, + }, + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Can you give me examples?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.5), + MaxTokens: bifrost.Ptr(150), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Logf("Testing content variation: %s", tt.name) + + // Make first request + _, err1 := setup.Client.ChatCompletionRequest(ctx, tt.request) + if err1 != nil { + t.Logf("⚠️ First %s request failed: %v", tt.name, err1) + return // Skip this test case + } + + WaitForCache() + + // Make second identical request + response2, err2 := setup.Client.ChatCompletionRequest(ctx, tt.request) + if err2 != nil { + t.Fatalf("Second %s request failed: %v", tt.name, err2) + } + + // Should be cached + AssertCacheHit(t, response2, string(CacheTypeDirect)) + t.Logf("βœ… %s content variation successful", tt.name) + }) + } +} + +// TestBoundaryParameterValues tests edge case parameter values +func TestBoundaryParameterValues(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("boundary-params-test") + + tests := []struct { + name string + request *schemas.BifrostRequest + }{ + { + name: "Maximum Parameter Values", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test max parameters"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(2.0), + MaxTokens: bifrost.Ptr(4096), + TopP: bifrost.Ptr(1.0), + PresencePenalty: bifrost.Ptr(2.0), + FrequencyPenalty: bifrost.Ptr(2.0), + }, + }, + }, + { + name: "Minimum Parameter Values", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test min parameters"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), + MaxTokens: bifrost.Ptr(1), + TopP: bifrost.Ptr(0.01), + PresencePenalty: bifrost.Ptr(-2.0), + FrequencyPenalty: bifrost.Ptr(-2.0), + }, + }, + }, + { + name: "Edge Case Parameters", + request: &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test edge case parameters"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), + MaxTokens: bifrost.Ptr(1), + TopP: bifrost.Ptr(0.1), + User: bifrost.Ptr("test-user-id-12345"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Logf("Testing boundary parameters: %s", tt.name) + + _, err := setup.Client.ChatCompletionRequest(ctx, tt.request) + if err != nil { + t.Logf("⚠️ %s request failed (may be expected): %v", tt.name, err) + } else { + t.Logf("βœ… %s handled gracefully", tt.name) + } + }) + } +} + +// TestSemanticSimilarityEdgeCases tests edge cases in semantic similarity matching +func TestSemanticSimilarityEdgeCases(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + setup.Config.Threshold = 0.9 + + ctx := CreateContextWithCacheKey("semantic-edge-test") + + // Test case: Similar questions with different wording + similarTests := []struct { + prompt1 string + prompt2 string + shouldMatch bool + description string + }{ + { + prompt1: "What is machine learning?", + prompt2: "Can you explain machine learning?", + shouldMatch: true, + description: "Similar questions about ML", + }, + { + prompt1: "How does AI work?", + prompt2: "Explain artificial intelligence", + shouldMatch: true, + description: "AI-related questions", + }, + { + prompt1: "What is the weather today?", + prompt2: "What do you know about bifrost?", + shouldMatch: false, + description: "Completely different topics", + }, + { + prompt1: "Hello, how are you?", + prompt2: "Hi, how are you doing?", + shouldMatch: true, + description: "Similar greetings", + }, + } + + for i, test := range similarTests { + t.Run(test.description, func(t *testing.T) { + // Clear cache for this subtest + clearTestKeysWithStore(t, setup.Store) + + // Make first request + request1 := CreateBasicChatRequest(test.prompt1, 0.1, 50) + _, err1 := setup.Client.ChatCompletionRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + // Wait for cache to be written + WaitForCache() + + // Make second request with similar content + request2 := CreateBasicChatRequest(test.prompt2, 0.1, 50) // Same parameters + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + var cacheThresholdFloat float64 + var cacheSimilarityFloat float64 + + // Check if semantic matching occurred + semanticMatch := false + if response2.ExtraFields.CacheDebug != nil && response2.ExtraFields.CacheDebug.CacheHit { + if response2.ExtraFields.CacheDebug.HitType != nil && *response2.ExtraFields.CacheDebug.HitType == string(CacheTypeSemantic) { + semanticMatch = true + + if response2.ExtraFields.CacheDebug.Threshold != nil { + cacheThresholdFloat = *response2.ExtraFields.CacheDebug.Threshold + } + if response2.ExtraFields.CacheDebug.Similarity != nil { + cacheSimilarityFloat = *response2.ExtraFields.CacheDebug.Similarity + } + } + } + + if test.shouldMatch { + if semanticMatch { + t.Logf("βœ… Test %d: Semantic match found as expected for '%s'", i+1, test.description) + } else { + t.Logf("ℹ️ Test %d: No semantic match found for '%s', check with threshold: %f and found similarity: %f", i+1, test.description, cacheThresholdFloat, cacheSimilarityFloat) + } + } else { + if semanticMatch { + t.Errorf("❌ Test %d: Unexpected semantic match for different topics: '%s', check with threshold: %f and found similarity: %f", i+1, test.description, cacheThresholdFloat, cacheSimilarityFloat) + } else { + t.Logf("βœ… Test %d: Correctly no semantic match for different topics: '%s'", i+1, test.description) + } + } + }) + } +} + +// TestErrorHandlingEdgeCases tests various error scenarios +func TestErrorHandlingEdgeCases(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Test error handling scenarios", 0.5, 50) + + // Test without cache key (should not crash and bypass cache) + t.Run("Request without cache key", func(t *testing.T) { + ctxNoKey := context.Background() // No cache key + + response, err := setup.Client.ChatCompletionRequest(ctxNoKey, testRequest) + if err != nil { + t.Errorf("Request without cache key failed: %v", err) + return + } + + // Should bypass cache since there's no cache key + AssertNoCacheHit(t, response) + t.Log("βœ… Request without cache key correctly bypassed cache") + }) + + // Test with invalid cache key type + t.Run("Request with invalid cache key type", func(t *testing.T) { + // First establish a cached response with valid context + validCtx := CreateContextWithCacheKey("error-handling-test") + _, err := setup.Client.ChatCompletionRequest(validCtx, testRequest) + if err != nil { + t.Fatalf("First request with valid cache key failed: %v", err) + } + + WaitForCache() + + // Now test with invalid key type - should bypass cache + ctxInvalidKey := context.WithValue(context.Background(), CacheKey, 12345) // Wrong type (int instead of string) + + response, err := setup.Client.ChatCompletionRequest(ctxInvalidKey, testRequest) + if err != nil { + t.Errorf("Request with invalid cache key type failed: %v", err) + return + } + + // Should bypass cache due to invalid key type + AssertNoCacheHit(t, response) + t.Log("βœ… Request with invalid cache key type correctly bypassed cache") + }) + + t.Log("βœ… Error handling edge cases completed!") +} diff --git a/plugins/semanticcache/plugin_embedding_test.go b/plugins/semanticcache/plugin_embedding_test.go new file mode 100644 index 000000000..c6458b94d --- /dev/null +++ b/plugins/semanticcache/plugin_embedding_test.go @@ -0,0 +1,168 @@ +package semanticcache + +import ( + "testing" + "time" +) + +// TestEmbeddingRequestsCaching tests that embedding requests are properly cached using direct hash matching +func TestEmbeddingRequestsCaching(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-embedding-cache") + + // Create embedding request + embeddingRequest := CreateEmbeddingRequest([]string{ + "What is machine learning?", + "Explain artificial intelligence in simple terms.", + }) + + t.Log("Making first embedding request (should go to OpenAI and be cached)...") + + // Make first request (will go to OpenAI and be cached) + start1 := time.Now() + response1, err1 := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("First embedding request failed: %v", err1) + } + + if response1 == nil || len(response1.Data) == 0 { + t.Fatal("First embedding response is invalid") + } + + t.Logf("First embedding request completed in %v", duration1) + t.Logf("Response contains %d embeddings", len(response1.Data)) + + // Wait for cache to be written + WaitForCache() + + t.Log("Making second identical embedding request (should be served from cache)...") + + // Make second identical request (should be cached) + start2 := time.Now() + response2, err2 := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("Second embedding request failed: %v", err2) + } + + if response2 == nil || len(response2.Data) == 0 { + t.Fatal("Second embedding response is invalid") + } + + // Verify cache hit + AssertCacheHit(t, response2, "direct") + + t.Logf("Second embedding request completed in %v", duration2) + + // Cache should be significantly faster + if duration2 >= duration1 { // Allow some margin but cache should be much faster + t.Log("⚠️ Cache doesn't seem faster, but this could be due to test environment") + } + + // Responses should be identical + if len(response1.Data) != len(response2.Data) { + t.Errorf("Response lengths differ: %d vs %d", len(response1.Data), len(response2.Data)) + } + + t.Log("βœ… Embedding requests properly cached using direct hash matching") +} + +// TestEmbeddingRequestsNoCacheWithoutCacheKey tests that embedding requests without cache key are not cached +func TestEmbeddingRequestsNoCacheWithoutCacheKey(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Don't set cache key in context + ctx := CreateContextWithCacheKey("") + + embeddingRequest := CreateEmbeddingRequest([]string{"Test embedding without cache key"}) + + t.Log("Making embedding request without cache key...") + + response, err := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + if err != nil { + t.Fatalf("Embedding request failed: %v", err) + } + + // Should not be cached + AssertNoCacheHit(t, response) + + t.Log("βœ… Embedding requests without cache key are properly not cached") +} + +// TestEmbeddingRequestsDifferentTexts tests that different embedding texts produce different cache entries +func TestEmbeddingRequestsDifferentTexts(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-embedding-different") + + // Create two different embedding requests + request1 := CreateEmbeddingRequest([]string{"First set of texts"}) + request2 := CreateEmbeddingRequest([]string{"Second set of texts"}) + + t.Log("Making first embedding request...") + response1, err1 := setup.Client.EmbeddingRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + t.Log("Making second different embedding request...") + response2, err2 := setup.Client.EmbeddingRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + // Should not be a cache hit since texts are different + AssertNoCacheHit(t, response2) + + t.Log("βœ… Different embedding texts produce different cache entries") +} + +// TestEmbeddingRequestsCacheExpiration tests TTL functionality for embedding requests +func TestEmbeddingRequestsCacheExpiration(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Set very short TTL for testing + shortTTL := 2 * time.Second + ctx := CreateContextWithCacheKeyAndTTL("test-embedding-ttl", shortTTL) + + embeddingRequest := CreateEmbeddingRequest([]string{"TTL test embedding"}) + + t.Log("Making first embedding request with short TTL...") + response1, err1 := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + t.Log("Making second request before TTL expiration...") + response2, err2 := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + t.Logf("Waiting for TTL expiration (%v)...", shortTTL) + time.Sleep(shortTTL + 1*time.Second) // Wait for TTL to expire + + t.Log("Making third request after TTL expiration...") + response3, err3 := setup.Client.EmbeddingRequest(ctx, embeddingRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + // Should not be a cache hit since TTL expired + AssertNoCacheHit(t, response3) + + t.Log("βœ… Embedding requests properly handle TTL expiration") +} diff --git a/plugins/semanticcache/plugin_integration_test.go b/plugins/semanticcache/plugin_integration_test.go new file mode 100644 index 000000000..16c80e342 --- /dev/null +++ b/plugins/semanticcache/plugin_integration_test.go @@ -0,0 +1,693 @@ +package semanticcache + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestSemanticCacheBasicFlow tests the complete semantic cache flow +func TestSemanticCacheBasicFlow(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := context.Background() + + // Add cache key to context + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + // Test request + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, world!"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + MaxTokens: bifrost.Ptr(100), + }, + } + + t.Log("Testing first request (cache miss)...") + + // First request - should be a cache miss + modifiedReq, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Expected cache miss, but got cache hit") + } + + if modifiedReq == nil { + t.Fatal("Modified request is nil") + } + + t.Log("βœ… Cache miss handled correctly") + + // Simulate a response + response := &schemas.BifrostResponse{ + ID: uuid.New().String(), + Choices: []schemas.BifrostResponseChoice{ + { + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: "assistant", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello! How can I help you today?"), + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + }, + } + + // Capture original response content for comparison + var originalContent string + if len(response.Choices) > 0 && response.Choices[0].Message.Content.ContentStr != nil { + originalContent = *response.Choices[0].Message.Content.ContentStr + } + if originalContent == "" { + t.Fatal("Original response content is empty") + } + t.Logf("Original response content: %s", originalContent) + + // Cache the response + t.Log("Caching response...") + _, _, err = setup.Plugin.PostHook(&ctx, response, nil) + if err != nil { + t.Fatalf("PostHook failed: %v", err) + } + + // Wait for async caching to complete + WaitForCache() + t.Log("βœ… Response cached successfully") + + // Second request - should be a cache hit + t.Log("Testing second identical request (expecting cache hit)...") + + // Reset context for second request + ctx2 := context.Background() + ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") + ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + modifiedReq2, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) + if err != nil { + t.Fatalf("Second PreHook failed: %v", err) + } + + if shortCircuit2 == nil { + t.Fatal("expected cache hit on identical request") + return + } + + if shortCircuit2.Response == nil { + t.Fatal("Cache hit but response is nil") + } + + if modifiedReq2 == nil { + t.Fatal("Modified request is nil on cache hit") + } + + t.Log("βœ… Cache hit detected and response returned") + + // Verify the cached response + if len(shortCircuit2.Response.Choices) == 0 { + t.Fatal("Cached response has no choices") + } + + cachedContent := shortCircuit2.Response.Choices[0].Message.Content.ContentStr + if cachedContent == nil || *cachedContent == "" { + t.Fatal("Cached response content is empty") + } + + t.Logf("βœ… Cached response content: %s", *cachedContent) + + // Compare original and cached content + cachedContentStr := *cachedContent + // Trim whitespace and newlines for comparison + originalContentTrimmed := strings.TrimSpace(originalContent) + cachedContentTrimmed := strings.TrimSpace(cachedContentStr) + + if originalContentTrimmed != cachedContentTrimmed { + t.Fatalf("❌ Content mismatch: original='%s', cached='%s'", originalContentTrimmed, cachedContentTrimmed) + } + + t.Log("βœ… Content verification passed - original and cached responses match") + t.Log("πŸŽ‰ Basic semantic cache flow test passed!") +} + +// TestSemanticCacheStrictFiltering tests that the cache respects parameter differences +func TestSemanticCacheStrictFiltering(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + // Base request + baseRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What is the weather like?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + MaxTokens: bifrost.Ptr(100), + }, + } + + t.Log("Testing first request with temperature=0.7...") + + // First request + _, shortCircuit1, err := setup.Plugin.PreHook(&ctx, baseRequest) + if err != nil { + t.Fatalf("First PreHook failed: %v", err) + } + + if shortCircuit1 != nil { + t.Fatal("Expected cache miss for first request") + } + + // Cache a response + response := &schemas.BifrostResponse{ + ID: uuid.New().String(), + Choices: []schemas.BifrostResponseChoice{ + { + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: "assistant", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("It's sunny today!"), + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + }, + } + + _, _, err = setup.Plugin.PostHook(&ctx, response, nil) + if err != nil { + t.Fatalf("PostHook failed: %v", err) + } + + WaitForCache() + t.Log("βœ… First response cached") + + // Second request with different temperature - should be cache miss + t.Log("Testing second request with temperature=0.5 (expecting cache miss)...") + + ctx2 := context.Background() + ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") + ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + modifiedRequest := *baseRequest + modifiedRequest.Params = &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.5), // Different temperature + MaxTokens: bifrost.Ptr(100), + } + + _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, &modifiedRequest) + if err != nil { + t.Fatalf("Second PreHook failed: %v", err) + } + + if shortCircuit2 != nil { + t.Fatal("Expected cache miss due to different temperature, but got cache hit") + } + + t.Log("βœ… Strict filtering working - different parameters result in cache miss") + + // Third request with different model - should be cache miss + t.Log("Testing third request with different model (expecting cache miss)...") + + ctx3 := context.Background() + ctx3 = context.WithValue(ctx3, CacheKey, "test-cache-enabled") + ctx3 = context.WithValue(ctx3, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + modifiedRequest2 := *baseRequest + modifiedRequest2.Model = "gpt-3.5-turbo" // Different model + + _, shortCircuit3, err := setup.Plugin.PreHook(&ctx3, &modifiedRequest2) + if err != nil { + t.Fatalf("Third PreHook failed: %v", err) + } + + if shortCircuit3 != nil { + t.Fatal("Expected cache miss due to different model, but got cache hit") + } + + t.Log("βœ… Strict filtering working - different model results in cache miss") + t.Log("πŸŽ‰ Strict filtering test passed!") +} + +// TestSemanticCacheStreamingFlow tests streaming response caching +func TestSemanticCacheStreamingFlow(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionStreamRequest) + + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Tell me a short story"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.8), + }, + } + + t.Log("Testing streaming request (cache miss)...") + + // First request - should be cache miss + _, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Expected cache miss for streaming request") + } + + t.Log("βœ… Streaming cache miss handled correctly") + + // Simulate streaming response chunks + t.Log("Caching streaming response chunks...") + + chunks := []string{ + "Once upon a time,", + " there was a brave", + " knight who saved the day.", + } + + for i, chunk := range chunks { + var finishReason *string + if i == len(chunks)-1 { + finishReason = bifrost.Ptr("stop") + } + + chunkResponse := &schemas.BifrostResponse{ + ID: uuid.New().String(), + Choices: []schemas.BifrostResponseChoice{ + { + Index: i, + FinishReason: finishReason, + BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ + Delta: schemas.BifrostStreamDelta{ + Content: bifrost.Ptr(chunk), + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + ChunkIndex: i, + }, + } + + _, _, err = setup.Plugin.PostHook(&ctx, chunkResponse, nil) + if err != nil { + t.Fatalf("PostHook failed for chunk %d: %v", i, err) + } + } + + WaitForCache() + t.Log("βœ… Streaming response chunks cached") + + // Test cache retrieval for streaming + t.Log("Testing streaming cache retrieval...") + + ctx2 := context.Background() + ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") + ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionStreamRequest) + + _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) + if err != nil { + t.Fatalf("Second PreHook failed: %v", err) + } + + if shortCircuit2 == nil { + t.Log("⚠️ Expected streaming cache hit, but got cache miss - this may be expected with the new unified storage") + return + } + + if shortCircuit2.Stream == nil { + t.Fatal("Cache hit but stream is nil") + } + + t.Log("βœ… Streaming cache hit detected") + + // Read from the cached stream + chunkCount := 0 + for chunk := range shortCircuit2.Stream { + if chunk.BifrostResponse == nil { + continue + } + chunkCount++ + t.Logf("Received cached chunk %d", chunkCount) + } + + if chunkCount == 0 { + t.Fatal("No chunks received from cached stream") + } + + t.Logf("βœ… Received %d cached chunks", chunkCount) + t.Log("πŸŽ‰ Streaming cache test passed!") +} + +// TestSemanticCache_NoCacheWhenKeyMissing verifies cache is disabled when cache key is missing from context +func TestSemanticCache_NoCacheWhenKeyMissing(t *testing.T) { + t.Log("Testing cache behavior when cache key is missing...") + + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := context.Background() + // Don't set the cache key - cache should be disabled + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + _, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Expected no caching when cache key is not set, but got cache hit") + } + + t.Log("βœ… Cache properly disabled when no cache key is set") + t.Log("πŸŽ‰ No cache key test passed!") +} + +// TestSemanticCache_CustomTTLHandling verifies cache respects custom TTL values from context +func TestSemanticCache_CustomTTLHandling(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Configure plugin with custom TTL key + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, CacheTTLKey, 1*time.Minute) // Custom TTL + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("TTL test message"), + }, + }, + }, + }, + } + + // First request - cache miss + _, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Expected cache miss, but got cache hit") + } + + // Simulate response and cache it + response := &schemas.BifrostResponse{ + ID: "ttl-test-response", + Choices: []schemas.BifrostResponseChoice{ + { + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: "assistant", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("TTL test response"), + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + }, + } + + _, _, err = setup.Plugin.PostHook(&ctx, response, nil) + if err != nil { + t.Fatalf("PostHook failed: %v", err) + } + + WaitForCache() + + t.Log("βœ… Custom TTL configuration test passed!") +} + +// TestSemanticCache_CustomThresholdHandling verifies cache respects custom similarity threshold from context +func TestSemanticCache_CustomThresholdHandling(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Configure plugin with custom threshold key + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, CacheThresholdKey, 0.95) // Very high threshold + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Threshold test message"), + }, + }, + }, + }, + } + + // Test that custom threshold is used (this would need semantic search to be fully testable) + _, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Expected cache miss with high threshold, but got cache hit") + } + + t.Log("βœ… Custom threshold configuration test passed!") +} + +// TestSemanticCache_ProviderModelCachingFlags verifies cache behavior with provider/model caching flags +func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Test with provider/model caching disabled + setup.Config.CacheByProvider = bifrost.Ptr(false) + setup.Config.CacheByModel = bifrost.Ptr(false) + + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + request1 := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Provider model flags test"), + }, + }, + }, + }, + } + + // First request with OpenAI + _, shortCircuit1, err := setup.Plugin.PreHook(&ctx, request1) + if err != nil { + t.Fatalf("PreHook failed: %v", err) + } + + if shortCircuit1 != nil { + t.Fatal("Expected cache miss, but got cache hit") + } + + // Cache the response + response := &schemas.BifrostResponse{ + ID: "provider-model-test", + Choices: []schemas.BifrostResponseChoice{ + { + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: schemas.BifrostMessage{ + Role: "assistant", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Provider model test response"), + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + }, + } + + _, _, err = setup.Plugin.PostHook(&ctx, response, nil) + if err != nil { + t.Fatalf("PostHook failed: %v", err) + } + + WaitForCache() + + // Second request with different provider - should potentially hit cache since provider is not considered + request2 := &schemas.BifrostRequest{ + Provider: schemas.Anthropic, // Different provider + Model: "claude-3-haiku", // Different model + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Provider model flags test"), // Same content + }, + }, + }, + }, + } + + ctx2 := context.Background() + ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") + ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request2) + if err != nil { + t.Fatalf("Second PreHook failed: %v", err) + } + + // With provider/model caching disabled, we might get cache hits across different providers/models + // This behavior depends on the exact implementation of hash generation + t.Logf("Cache behavior with disabled provider/model flags: hit=%v", shortCircuit2 != nil) + + t.Log("βœ… Provider/model caching flags test passed!") +} + +// TestSemanticCache_ConfigurationEdgeCases verifies edge cases in configuration handling +func TestSemanticCache_ConfigurationEdgeCases(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Test with invalid TTL type in context + ctx := context.Background() + ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") + ctx = context.WithValue(ctx, CacheTTLKey, "not-a-duration") // Invalid TTL type + ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + request := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Edge case test"), + }, + }, + }, + }, + } + + // Should handle invalid TTL gracefully + _, shortCircuit, err := setup.Plugin.PreHook(&ctx, request) + if err != nil { + t.Fatalf("PreHook failed with invalid TTL: %v", err) + } + + if shortCircuit != nil { + t.Fatal("Unexpected cache hit with invalid TTL") + } + + // Test with invalid threshold type + ctx2 := context.Background() + ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") + ctx2 = context.WithValue(ctx2, CacheThresholdKey, "not-a-float") // Invalid threshold type + ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) + + // Should handle invalid threshold gracefully + _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) + if err != nil { + t.Fatalf("PreHook failed with invalid threshold: %v", err) + } + + if shortCircuit2 != nil { + t.Fatal("Unexpected cache hit with invalid threshold") + } + + t.Log("βœ… Configuration edge cases test passed!") +} diff --git a/plugins/semanticcache/plugin_no_store_test.go b/plugins/semanticcache/plugin_no_store_test.go new file mode 100644 index 000000000..621563954 --- /dev/null +++ b/plugins/semanticcache/plugin_no_store_test.go @@ -0,0 +1,311 @@ +package semanticcache + +import ( + "context" + "testing" +) + +// TestCacheNoStoreBasicFunctionality tests that CacheNoStoreKey prevents caching +func TestCacheNoStoreBasicFunctionality(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("What is artificial intelligence?", 0.7, 100) + + // Test 1: Normal caching (control test) + ctx1 := CreateContextWithCacheKey("test-no-store-control") + t.Log("Making normal request (should be cached)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) // Fresh request + + WaitForCache() + + // Verify it got cached + t.Log("Verifying normal caching worked...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should be cached + + // Test 2: NoStore = true (should not cache) + ctx2 := CreateContextWithCacheKeyAndNoStore("test-no-store-disabled", true) + t.Log("Making request with CacheNoStoreKey=true (should not be cached)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) // Fresh request + + WaitForCache() + + // Verify it was NOT cached + t.Log("Verifying no-store request was not cached...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertNoCacheHit(t, response4) // Should still be fresh (not cached) + + // Test 3: NoStore = false (should cache normally) + ctx3 := CreateContextWithCacheKeyAndNoStore("test-no-store-enabled", false) + t.Log("Making request with CacheNoStoreKey=false (should be cached)...") + response5, err5 := setup.Client.ChatCompletionRequest(ctx3, testRequest) + if err5 != nil { + t.Fatalf("Fifth request failed: %v", err5) + } + AssertNoCacheHit(t, response5) // Fresh request + + WaitForCache() + + // Verify it got cached + t.Log("Verifying no-store=false request was cached...") + response6, err6 := setup.Client.ChatCompletionRequest(ctx3, testRequest) + if err6 != nil { + t.Fatalf("Sixth request failed: %v", err6) + } + AssertCacheHit(t, response6, "direct") // Should be cached + + t.Log("βœ… CacheNoStoreKey basic functionality works correctly") +} + +// TestCacheNoStoreWithDifferentRequestTypes tests NoStore with various request types +func TestCacheNoStoreWithDifferentRequestTypes(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Test with chat completion + chatRequest := CreateBasicChatRequest("Test no-store with chat", 0.7, 50) + ctx1 := CreateContextWithCacheKeyAndNoStore("test-no-store-chat", true) + + t.Log("Testing no-store with chat completion...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, chatRequest) + if err1 != nil { + t.Fatalf("Chat request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Verify not cached + response2, err2 := setup.Client.ChatCompletionRequest(ctx1, chatRequest) + if err2 != nil { + t.Fatalf("Second chat request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should not be cached + + // Test with embedding request + embeddingRequest := CreateEmbeddingRequest([]string{"Test no-store with embeddings"}) + ctx2 := CreateContextWithCacheKeyAndNoStore("test-no-store-embedding", true) + + t.Log("Testing no-store with embedding request...") + response3, err3 := setup.Client.EmbeddingRequest(ctx2, embeddingRequest) + if err3 != nil { + t.Fatalf("Embedding request failed: %v", err3) + } + AssertNoCacheHit(t, response3) + + WaitForCache() + + // Verify not cached + response4, err4 := setup.Client.EmbeddingRequest(ctx2, embeddingRequest) + if err4 != nil { + t.Fatalf("Second embedding request failed: %v", err4) + } + AssertNoCacheHit(t, response4) // Should not be cached + + t.Log("βœ… CacheNoStoreKey works with different request types") +} + +// TestCacheNoStoreWithConversationHistory tests NoStore with conversation context +func TestCacheNoStoreWithConversationHistory(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Create conversation context + conversation := BuildConversationHistory( + "You are a helpful assistant", + []string{"Hello", "Hi! How can I help?"}, + ) + messages := AddUserMessage(conversation, "What is machine learning?") + request := CreateConversationRequest(messages, 0.7, 100) + + // Test with no-store enabled + ctx := CreateContextWithCacheKeyAndNoStore("test-no-store-conversation", true) + + t.Log("Testing no-store with conversation history...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, request) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Verify not cached (same conversation should not hit cache) + response2, err2 := setup.Client.ChatCompletionRequest(ctx, request) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should not be cached due to no-store + + t.Log("βœ… CacheNoStoreKey works with conversation history") +} + +// TestCacheNoStoreWithCacheTypes tests NoStore interaction with CacheTypeKey +func TestCacheNoStoreWithCacheTypes(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Test no-store with cache types", 0.7, 50) + + // Test no-store with direct cache type + ctx1 := CreateContextWithCacheKey("test-no-store-cache-types") + ctx1 = context.WithValue(ctx1, CacheNoStoreKey, true) + ctx1 = context.WithValue(ctx1, CacheTypeKey, CacheTypeDirect) + + t.Log("Testing no-store with CacheTypeKey=direct...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Should not be cached + response2, err2 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // No-store should override cache type + + // Test no-store with semantic cache type + ctx2 := CreateContextWithCacheKey("test-no-store-cache-types") + ctx2 = context.WithValue(ctx2, CacheNoStoreKey, true) + ctx2 = context.WithValue(ctx2, CacheTypeKey, CacheTypeSemantic) + + t.Log("Testing no-store with CacheTypeKey=semantic...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) + + WaitForCache() + + // Should not be cached + response4, err4 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertNoCacheHit(t, response4) // No-store should override cache type + + t.Log("βœ… CacheNoStoreKey correctly overrides cache type settings") +} + +// TestCacheNoStoreErrorHandling tests error scenarios with NoStore +func TestCacheNoStoreErrorHandling(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Test no-store error handling", 0.7, 50) + + // Test with invalid no-store value (non-boolean) + ctx1 := CreateContextWithCacheKey("test-no-store-errors") + ctx1 = context.WithValue(ctx1, CacheNoStoreKey, "invalid") + + t.Log("Testing no-store with invalid value (should cache normally)...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Should be cached (invalid value should be ignored) + response2, err2 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") // Should be cached (invalid value ignored) + + // Test with nil value (should cache normally) + ctx2 := CreateContextWithCacheKey("test-no-store-nil") + ctx2 = context.WithValue(ctx2, CacheNoStoreKey, nil) + + t.Log("Testing no-store with nil value (should cache normally)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + AssertNoCacheHit(t, response3) + + WaitForCache() + + // Should be cached (nil should be treated as normal caching) + response4, err4 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + AssertCacheHit(t, response4, "direct") // Should be cached (nil ignored) + + t.Log("βœ… CacheNoStoreKey error handling works correctly") +} + +// TestCacheNoStoreReadButNoWrite tests that NoStore allows reading cache but prevents writing +func TestCacheNoStoreReadButNoWrite(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + testRequest := CreateBasicChatRequest("Describe Isaac Newton's three laws of motion", 0.7, 50) + + // Step 1: Cache a response normally + ctx1 := CreateContextWithCacheKey("test-no-store-read") + t.Log("Caching response normally...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Step 2: Try to read with no-store enabled (should still read from cache) + ctx2 := CreateContextWithCacheKeyAndNoStore("test-no-store-read", true) + t.Log("Reading with no-store enabled (should still hit cache for reads)...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + // The current implementation should still read from cache even with no-store + // (no-store only affects writing, not reading) + AssertCacheHit(t, response2, "direct") + + // Step 3: Make a semantically similar request with no-store (strong paraphrase for deterministic semantic hit) + newRequest := CreateBasicChatRequest("Describe the three laws of motion by Isaac Newton", 0.7, 50) + t.Log("Making semantically similar request with no-store (should get semantic hit, but not cache response)...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx2, newRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + // Should get semantic cache hit (no-store allows reads, just prevents writes) + AssertCacheHit(t, response3, "semantic") + + WaitForCache() + + // Step 4: Repeat similar request with no-store (should still get semantic hit) + t.Log("Repeating similar request with no-store (should still get semantic hit)...") + response4, err4 := setup.Client.ChatCompletionRequest(ctx2, newRequest) + if err4 != nil { + t.Fatalf("Fourth request failed: %v", err4) + } + // Should get semantic cache hit again (consistent behavior) + AssertCacheHit(t, response4, "semantic") + + t.Log("βœ… CacheNoStoreKey allows reading but prevents writing") +} diff --git a/plugins/semanticcache/plugin_normalization_test.go b/plugins/semanticcache/plugin_normalization_test.go new file mode 100644 index 000000000..23bb7f30e --- /dev/null +++ b/plugins/semanticcache/plugin_normalization_test.go @@ -0,0 +1,332 @@ +package semanticcache + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestTextNormalizationDirectCache tests that text normalization works correctly +// for direct cache (hash-based) matching across all input types +func TestTextNormalizationDirectCache(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + t.Run("ChatCompletion", func(t *testing.T) { + testChatCompletionNormalization(t, setup) + }) + + t.Run("Speech", func(t *testing.T) { + testSpeechNormalization(t, setup) + }) +} + +func testChatCompletionNormalization(t *testing.T, setup *TestSetup) { + ctx := CreateContextWithCacheKey("test-chat-normalization") + + // Test cases with different case and whitespace variations + testCases := []struct { + name string + userMsg string + systemMsg string + }{ + { + name: "Original", + userMsg: "Explain quantum physics", + systemMsg: "You are a helpful science teacher", + }, + { + name: "Lowercase", + userMsg: "explain quantum physics", + systemMsg: "you are a helpful science teacher", + }, + { + name: "Uppercase", + userMsg: "EXPLAIN QUANTUM PHYSICS", + systemMsg: "YOU ARE A HELPFUL SCIENCE TEACHER", + }, + { + name: "Mixed Case", + userMsg: "ExPlAiN QuAnTuM PhYsIcS", + systemMsg: "YoU aRe A hElPfUl ScIeNcE tEaChEr", + }, + { + name: "With Whitespace", + userMsg: " Explain quantum physics ", + systemMsg: " You are a helpful science teacher ", + }, + { + name: "Extra Whitespace", + userMsg: " Explain quantum physics ", + systemMsg: " You are a helpful science teacher ", + }, + } + + // Create chat completion requests for all test cases + requests := make([]*schemas.BifrostRequest, len(testCases)) + for i, tc := range testCases { + requests[i] = &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: &tc.systemMsg, + }, + }, + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &tc.userMsg, + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: PtrFloat64(0.5), + MaxTokens: PtrInt(50), + }, + } + } + + // Make first request (should miss cache and be stored) + t.Logf("Making first request with user: '%s', system: '%s'", testCases[0].userMsg, testCases[0].systemMsg) + response1, err1 := setup.Client.ChatCompletionRequest(ctx, requests[0]) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + if response1 == nil || len(response1.Choices) == 0 { + t.Fatal("First response is invalid") + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Test all other variations should hit cache due to normalization + for i := 1; i < len(testCases); i++ { + tc := testCases[i] + t.Logf("Testing variation '%s' with user: '%s', system: '%s'", tc.name, tc.userMsg, tc.systemMsg) + + response, err := setup.Client.ChatCompletionRequest(ctx, requests[i]) + if err != nil { + t.Fatalf("Request for case '%s' failed: %v", tc.name, err) + } + + if response == nil || len(response.Choices) == 0 { + t.Fatalf("Response for case '%s' is invalid", tc.name) + } + + // Should be cache hit due to normalization + AssertCacheHit(t, response, "direct") + t.Logf("βœ“ Cache hit for '%s' variation", tc.name) + } +} + +func testSpeechNormalization(t *testing.T, setup *TestSetup) { + ctx := CreateContextWithCacheKey("test-speech-normalization") + + // Test cases with different case and whitespace variations for speech input + testCases := []struct { + name string + input string + }{ + {"Original", "Hello, this is a test speech synthesis"}, + {"Lowercase", "hello, this is a test speech synthesis"}, + {"Uppercase", "HELLO, THIS IS A TEST SPEECH SYNTHESIS"}, + {"Mixed Case", "HeLLo, ThIs Is A tEsT sPeEcH sYnThEsIs"}, + {"Leading Whitespace", " Hello, this is a test speech synthesis"}, + {"Trailing Whitespace", "Hello, this is a test speech synthesis "}, + {"Both Whitespace", " Hello, this is a test speech synthesis "}, + {"Extra Spaces", " Hello, this is a test speech synthesis "}, + } + + // Create speech requests for all test cases + requests := make([]*schemas.BifrostRequest, len(testCases)) + for i, tc := range testCases { + requests[i] = CreateSpeechRequest(tc.input, "alloy") + } + + // Make first request (should miss cache and be stored) + t.Logf("Making first speech request with: '%s'", testCases[0].input) + response1, err1 := setup.Client.SpeechRequest(ctx, requests[0]) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + if response1 == nil || response1.Speech == nil { + t.Fatal("First response is invalid") + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Test all other variations should hit cache due to normalization + for i := 1; i < len(testCases); i++ { + tc := testCases[i] + t.Logf("Testing variation '%s' with input: '%s'", tc.name, tc.input) + + response, err := setup.Client.SpeechRequest(ctx, requests[i]) + if err != nil { + t.Fatalf("Request for case '%s' failed: %v", tc.name, err) + } + + if response == nil || response.Speech == nil { + t.Fatalf("Response for case '%s' is invalid", tc.name) + } + + // Should be cache hit due to normalization + AssertCacheHit(t, response, "direct") + t.Logf("βœ“ Cache hit for '%s' variation", tc.name) + } +} + +// TestChatCompletionContentBlocksNormalization tests normalization for content blocks +func TestChatCompletionContentBlocksNormalization(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-content-blocks-normalization") + + // Test cases with content blocks having different text normalization + testCases := []struct { + name string + textBlocks []string + }{ + { + name: "Original", + textBlocks: []string{"Hello World", "How are you today?"}, + }, + { + name: "Lowercase", + textBlocks: []string{"hello world", "how are you today?"}, + }, + { + name: "With Whitespace", + textBlocks: []string{" Hello World ", " How are you today? "}, + }, + { + name: "Mixed Case", + textBlocks: []string{"HeLLo WoRLd", "HoW aRe YoU tOdAy?"}, + }, + } + + // Create chat completion requests with content blocks + requests := make([]*schemas.BifrostRequest, len(testCases)) + for i, tc := range testCases { + // Create content blocks + contentBlocks := make([]schemas.ContentBlock, len(tc.textBlocks)) + for j, text := range tc.textBlocks { + contentBlocks[j] = schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &text, + } + } + + requests[i] = &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: PtrFloat64(0.5), + MaxTokens: PtrInt(50), + }, + } + } + + // Make first request (should miss cache and be stored) + t.Logf("Making first request with content blocks: %v", testCases[0].textBlocks) + response1, err1 := setup.Client.ChatCompletionRequest(ctx, requests[0]) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + if response1 == nil || len(response1.Choices) == 0 { + t.Fatal("First response is invalid") + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Test all other variations should hit cache due to normalization + for i := 1; i < len(testCases); i++ { + tc := testCases[i] + t.Logf("Testing variation '%s' with content blocks: %v", tc.name, tc.textBlocks) + + response, err := setup.Client.ChatCompletionRequest(ctx, requests[i]) + if err != nil { + t.Fatalf("Request for case '%s' failed: %v", tc.name, err) + } + + if response == nil || len(response.Choices) == 0 { + t.Fatalf("Response for case '%s' is invalid", tc.name) + } + + // Should be cache hit due to normalization + AssertCacheHit(t, response, "direct") + t.Logf("βœ“ Cache hit for '%s' variation", tc.name) + } +} + +// TestNormalizationWithSemanticCache tests that normalization works with semantic cache as well +func TestNormalizationWithSemanticCache(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-normalization-semantic") + + // Make first request with original text + originalRequest := CreateBasicChatRequest("What is Machine Learning?", 0.5, 50) + t.Log("Making first request with original text...") + response1, err1 := setup.Client.ChatCompletionRequest(ctx, originalRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Test semantic match with different case (should hit semantic cache after normalization) + normalizedRequest := CreateBasicChatRequest("what is machine learning?", 0.5, 50) + t.Log("Making semantic request with normalized case...") + response2, err2 := setup.Client.ChatCompletionRequest(ctx, normalizedRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // This should be a direct cache hit since the normalized text is identical + AssertCacheHit(t, response2, "direct") + t.Log("βœ“ Direct cache hit with normalized text") + + // Test with semantically similar but different text + semanticRequest := CreateBasicChatRequest("can you explain machine learning concepts?", 0.5, 50) + t.Log("Making semantically similar request...") + response3, err3 := setup.Client.ChatCompletionRequest(ctx, semanticRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + + // This should be a semantic cache hit + AssertCacheHit(t, response3, "semantic") + t.Log("βœ“ Semantic cache hit with similar content") +} + +// Helper functions for pointer creation +func PtrFloat64(f float64) *float64 { + return &f +} + +func PtrInt(i int) *int { + return &i +} diff --git a/plugins/semanticcache/plugin_streaming_test.go b/plugins/semanticcache/plugin_streaming_test.go new file mode 100644 index 000000000..5934e9119 --- /dev/null +++ b/plugins/semanticcache/plugin_streaming_test.go @@ -0,0 +1,333 @@ +package semanticcache + +import ( + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestStreamingCacheBasicFunctionality tests streaming response caching +func TestStreamingCacheBasicFunctionality(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-stream-value") + + // Create a test streaming request + testRequest := CreateStreamingChatRequest( + "Count from 1 to 3, each number on a new line.", + 0.0, // Use 0 temperature for more predictable responses + 20, + ) + + t.Log("Making first streaming request (should go to OpenAI and be cached)...") + + // Make first streaming request + start1 := time.Now() + stream1, err1 := setup.Client.ChatCompletionStreamRequest(ctx, testRequest) + if err1 != nil { + t.Fatalf("First streaming request failed: %v", err1) + } + + var responses1 []schemas.BifrostResponse + for streamMsg := range stream1 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in first stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + responses1 = append(responses1, *streamMsg.BifrostResponse) + } + } + duration1 := time.Since(start1) + + if len(responses1) == 0 { + t.Fatal("First streaming request returned no responses") + } + + t.Logf("First streaming request completed in %v with %d chunks", duration1, len(responses1)) + + // Wait for cache to be written + WaitForCache() + + t.Log("Making second identical streaming request (should be served from cache)...") + + // Make second identical streaming request + start2 := time.Now() + stream2, err2 := setup.Client.ChatCompletionStreamRequest(ctx, testRequest) + if err2 != nil { + t.Fatalf("Second streaming request failed: %v", err2) + } + + var responses2 []schemas.BifrostResponse + for streamMsg := range stream2 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in second stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + responses2 = append(responses2, *streamMsg.BifrostResponse) + } + } + duration2 := time.Since(start2) + + if len(responses2) == 0 { + t.Fatal("Second streaming request returned no responses") + } + + t.Logf("Second streaming request completed in %v with %d chunks", duration2, len(responses2)) + + // Validate that both streams have the same number of chunks + if len(responses1) != len(responses2) { + t.Errorf("Stream chunk count mismatch: original=%d, cached=%d", len(responses1), len(responses2)) + } + + // Validate that the second stream was cached + cached := false + for _, response := range responses2 { + if response.ExtraFields.CacheDebug != nil && response.ExtraFields.CacheDebug.CacheHit { + cached = true + break + } + } + + if !cached { + t.Fatal("Second streaming request was not served from cache") + } + + // Validate performance improvement + if duration2 >= duration1 { + t.Errorf("Cached stream took longer than original: cache=%v, original=%v", duration2, duration1) + } else { + speedup := float64(duration1) / float64(duration2) + t.Logf("Streaming cache speedup: %.2fx faster", speedup) + } + + // Validate chunk ordering is maintained + for i := range responses2 { + if responses2[i].ExtraFields.ChunkIndex != responses1[i].ExtraFields.ChunkIndex { + t.Errorf("Chunk index mismatch at position %d: original=%d, cached=%d", + i, responses1[i].ExtraFields.ChunkIndex, responses2[i].ExtraFields.ChunkIndex) + } + } + + t.Log("βœ… Streaming cache test completed successfully!") +} + +// TestStreamingVsNonStreaming tests that streaming and non-streaming requests are cached separately +func TestStreamingVsNonStreaming(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("stream-vs-non-test") + + prompt := "What is the meaning of life?" + + // Make non-streaming request first + t.Log("Making non-streaming request...") + nonStreamRequest := CreateBasicChatRequest(prompt, 0.5, 50) + nonStreamResponse, err1 := setup.Client.ChatCompletionRequest(ctx, nonStreamRequest) + if err1 != nil { + t.Fatalf("Non-streaming request failed: %v", err1) + } + + WaitForCache() + + // Make streaming request with same prompt and parameters + t.Log("Making streaming request with same prompt...") + streamRequest := CreateStreamingChatRequest(prompt, 0.5, 50) + stream, err2 := setup.Client.ChatCompletionStreamRequest(ctx, streamRequest) + if err2 != nil { + t.Fatalf("Streaming request failed: %v", err2) + } + + var streamResponses []schemas.BifrostResponse + for streamMsg := range stream { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + streamResponses = append(streamResponses, *streamMsg.BifrostResponse) + } + } + + if len(streamResponses) == 0 { + t.Fatal("Streaming request returned no responses") + } + + // Verify that the streaming request was NOT served from the non-streaming cache + // (They should be cached separately) + streamCached := false + for _, response := range streamResponses { + if response.ExtraFields.RawResponse != nil { + if rawMap, ok := response.ExtraFields.RawResponse.(map[string]interface{}); ok { + if cachedFlag, exists := rawMap["bifrost_cached"]; exists { + if cachedBool, ok := cachedFlag.(bool); ok && cachedBool { + streamCached = true + break + } + } + } + } + } + + if streamCached { + t.Error("Streaming request should not be cached from non-streaming cache") + } else { + t.Log("βœ… Streaming request correctly not cached from non-streaming cache") + } + + // Verify non-streaming response was not affected + AssertNoCacheHit(t, nonStreamResponse) + + t.Log("βœ… Streaming vs non-streaming test completed!") +} + +// TestStreamingChunkOrdering tests that cached streaming responses maintain proper chunk ordering +func TestStreamingChunkOrdering(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("chunk-order-test") + + // Request that should generate multiple chunks + testRequest := CreateStreamingChatRequest( + "List the first 5 prime numbers, one per line with explanation.", + 0.0, + 100, + ) + + t.Log("Making first streaming request to establish cache...") + stream1, err1 := setup.Client.ChatCompletionStreamRequest(ctx, testRequest) + if err1 != nil { + t.Fatalf("First streaming request failed: %v", err1) + } + + var originalChunks []schemas.BifrostResponse + for streamMsg := range stream1 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in first stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + originalChunks = append(originalChunks, *streamMsg.BifrostResponse) + } + } + + if len(originalChunks) < 2 { + t.Skipf("Need at least 2 chunks to test ordering, got %d", len(originalChunks)) + } + + t.Logf("Original stream had %d chunks", len(originalChunks)) + + WaitForCache() + + t.Log("Making second streaming request to test cached chunk ordering...") + stream2, err2 := setup.Client.ChatCompletionStreamRequest(ctx, testRequest) + if err2 != nil { + t.Fatalf("Second streaming request failed: %v", err2) + } + + var cachedChunks []schemas.BifrostResponse + for streamMsg := range stream2 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in second stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + cachedChunks = append(cachedChunks, *streamMsg.BifrostResponse) + } + } + + if len(cachedChunks) != len(originalChunks) { + t.Errorf("Cached stream chunk count mismatch: original=%d, cached=%d", + len(originalChunks), len(cachedChunks)) + } + + // Verify chunk ordering + for i := 0; i < len(cachedChunks) && i < len(originalChunks); i++ { + originalIndex := originalChunks[i].ExtraFields.ChunkIndex + cachedIndex := cachedChunks[i].ExtraFields.ChunkIndex + + if originalIndex != cachedIndex { + t.Errorf("Chunk index mismatch at position %d: original=%d, cached=%d", + i, originalIndex, cachedIndex) + } + + // Only verify cache hit on the last chunk (where CacheDebug is set) + if i == len(cachedChunks)-1 { + AssertCacheHit(t, &cachedChunks[i], string(CacheTypeDirect)) + } + } + + // Verify chunks are in sequential order + for i := 1; i < len(cachedChunks); i++ { + prevIndex := cachedChunks[i-1].ExtraFields.ChunkIndex + currIndex := cachedChunks[i].ExtraFields.ChunkIndex + + if currIndex <= prevIndex { + t.Errorf("Chunks not in sequential order: chunk %d has index %d, chunk %d has index %d", + i-1, prevIndex, i, currIndex) + } + } + + t.Log("βœ… Streaming chunk ordering test completed successfully!") +} + +// TestSpeechSynthesisStreaming tests speech synthesis streaming caching +func TestSpeechSynthesisStreaming(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("speech-stream-test") + + // Create speech synthesis request + speechRequest := CreateSpeechRequest( + "This is a test of speech synthesis streaming cache.", + "alloy", + ) + + t.Log("Making first speech synthesis request...") + start1 := time.Now() + response1, err1 := setup.Client.SpeechRequest(ctx, speechRequest) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("First speech request failed: %v", err1) + } + + if response1 == nil { + t.Fatal("First speech response is nil") + } + + t.Logf("First speech request completed in %v", duration1) + + WaitForCache() + + t.Log("Making second identical speech synthesis request...") + start2 := time.Now() + response2, err2 := setup.Client.SpeechRequest(ctx, speechRequest) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("Second speech request failed: %v", err2) + } + + if response2 == nil { + t.Fatal("Second speech response is nil") + } + + t.Logf("Second speech request completed in %v", duration2) + + // Check if second request was cached + AssertCacheHit(t, response2, string(CacheTypeDirect)) + + // Performance comparison + t.Logf("Speech Synthesis Performance:") + t.Logf("First request: %v", duration1) + t.Logf("Second request: %v", duration2) + + if duration2 < duration1 { + speedup := float64(duration1) / float64(duration2) + t.Logf("Speech cache speedup: %.2fx faster", speedup) + } + + t.Log("βœ… Speech synthesis streaming test completed successfully!") +} diff --git a/plugins/semanticcache/search.go b/plugins/semanticcache/search.go new file mode 100644 index 000000000..30dffeec9 --- /dev/null +++ b/plugins/semanticcache/search.go @@ -0,0 +1,387 @@ +package semanticcache + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/vectorstore" +) + +func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType, cacheKey string) (*schemas.PluginShortCircuit, error) { + // Generate hash for the request + hash, err := plugin.generateRequestHash(req, requestType) + if err != nil { + return nil, fmt.Errorf("failed to generate request hash: %w", err) + } + + plugin.logger.Debug(PluginLoggerPrefix + " Generated Hash for Request: " + hash) + + // Extract metadata for strict filtering + _, paramsHash, err := plugin.extractTextForEmbedding(req, requestType) + if err != nil { + return nil, fmt.Errorf("failed to extract metadata for filtering: %w", err) + } + + // Store has and metadata in context + *ctx = context.WithValue(*ctx, requestHashKey, hash) + *ctx = context.WithValue(*ctx, requestParamsHashKey, paramsHash) + + // Build strict filters for direct hash search + filters := []vectorstore.Query{ + {Field: "request_hash", Operator: vectorstore.QueryOperatorEqual, Value: hash}, + {Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey}, + {Field: "params_hash", Operator: vectorstore.QueryOperatorEqual, Value: paramsHash}, + {Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true}, + } + + if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider { + filters = append(filters, vectorstore.Query{Field: "provider", Operator: vectorstore.QueryOperatorEqual, Value: string(req.Provider)}) + } + if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel { + filters = append(filters, vectorstore.Query{Field: "model", Operator: vectorstore.QueryOperatorEqual, Value: req.Model}) + } + + plugin.logger.Debug(fmt.Sprintf("%s Searching for direct hash match with %d filters", PluginLoggerPrefix, len(filters))) + + // Make a full copy so we don't mutate the original backing array + selectFields := append([]string(nil), SelectFields...) + if plugin.isStreamingRequest(requestType) { + selectFields = removeField(selectFields, "response") + } else { + selectFields = removeField(selectFields, "stream_chunks") + } + + // Search for entries with matching hash and all params + var cursor *string + results, _, err := plugin.store.GetAll(*ctx, plugin.config.VectorStoreNamespace, filters, selectFields, cursor, 1) + if err != nil { + if errors.Is(err, vectorstore.ErrNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to search for direct hash match: %w", err) + } + + if len(results) == 0 { + plugin.logger.Debug(PluginLoggerPrefix + " No direct hash match found") + return nil, nil + } + + // Found a matching entry - extract the response + result := results[0] + plugin.logger.Debug(fmt.Sprintf("%s Found direct hash match with ID: %s", PluginLoggerPrefix, result.ID)) + + // Build response from cached result + return plugin.buildResponseFromResult(ctx, req, result, CacheTypeDirect, 1.0, 0) +} + +// performSemanticSearch performs semantic similarity search and returns matching response if found. +func (plugin *Plugin) performSemanticSearch(ctx *context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType, cacheKey string) (*schemas.PluginShortCircuit, error) { + // Extract text and metadata for embedding + text, paramsHash, err := plugin.extractTextForEmbedding(req, requestType) + if err != nil { + return nil, fmt.Errorf("failed to extract text for embedding: %w", err) + } + + // Generate embedding + embedding, inputTokens, err := plugin.generateEmbedding(*ctx, text) + if err != nil { + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + // Store embedding and metadata in context for PostHook + *ctx = context.WithValue(*ctx, requestEmbeddingKey, embedding) + *ctx = context.WithValue(*ctx, requestEmbeddingTokensKey, inputTokens) + *ctx = context.WithValue(*ctx, requestParamsHashKey, paramsHash) + + cacheThreshold := plugin.config.Threshold + + thresholdValue := (*ctx).Value(CacheThresholdKey) + if thresholdValue != nil { + threshold, ok := thresholdValue.(float64) + if !ok { + plugin.logger.Warn(PluginLoggerPrefix + " Threshold is not a float64, using default threshold") + } else { + cacheThreshold = threshold + } + } + + // Build strict metadata filters as Query slices (provider, model, and all params) + strictFilters := []vectorstore.Query{ + {Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey}, + {Field: "params_hash", Operator: vectorstore.QueryOperatorEqual, Value: paramsHash}, + {Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true}, + } + + if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider { + strictFilters = append(strictFilters, vectorstore.Query{Field: "provider", Operator: vectorstore.QueryOperatorEqual, Value: string(req.Provider)}) + } + if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel { + strictFilters = append(strictFilters, vectorstore.Query{Field: "model", Operator: vectorstore.QueryOperatorEqual, Value: req.Model}) + } + + plugin.logger.Debug(fmt.Sprintf("%s Performing semantic search with %d metadata filters", PluginLoggerPrefix, len(strictFilters))) + + // Make a full copy so we don't mutate the original backing array + selectFields := append([]string(nil), SelectFields...) + if plugin.isStreamingRequest(requestType) { + selectFields = removeField(selectFields, "response") + } else { + selectFields = removeField(selectFields, "stream_chunks") + } + + // For semantic search, we want semantic similarity in content but exact parameter matching + results, err := plugin.store.GetNearest(*ctx, plugin.config.VectorStoreNamespace, embedding, strictFilters, selectFields, cacheThreshold, 1) + if err != nil { + return nil, fmt.Errorf("failed to search semantic cache: %w", err) + } + + if len(results) == 0 { + plugin.logger.Debug(PluginLoggerPrefix + " No semantic match found") + return nil, nil + } + + // Found a semantically similar entry + result := results[0] + plugin.logger.Debug(fmt.Sprintf("%s Found semantic match with ID: %s, Score: %f", PluginLoggerPrefix, result.ID, *result.Score)) + + // Build response from cached result + return plugin.buildResponseFromResult(ctx, req, result, CacheTypeSemantic, cacheThreshold, inputTokens) +} + +// buildResponseFromResult constructs a PluginShortCircuit response from a cached VectorEntry result +func (plugin *Plugin) buildResponseFromResult(ctx *context.Context, req *schemas.BifrostRequest, result vectorstore.SearchResult, cacheType CacheType, threshold float64, inputTokens int) (*schemas.PluginShortCircuit, error) { + // Extract response data from the result properties + properties := result.Properties + if properties == nil { + return nil, fmt.Errorf("no properties found in cached result") + } + + // Check TTL - if entry has expired, delete it and return cache miss + if expiresAtRaw, exists := properties["expires_at"]; exists && expiresAtRaw != nil { + var expiresAt int64 + var validType bool + switch v := expiresAtRaw.(type) { + case string: + var err error + expiresAt, err = strconv.ParseInt(v, 10, 64) + if err != nil { + validType = false + } else { + validType = true + } + case float64: + expiresAt = int64(v) + validType = true + case int64: + expiresAt = v + validType = true + case int: + expiresAt = int64(v) + validType = true + } + if validType { + currentTime := time.Now().Unix() + if expiresAt < currentTime { + // Entry has expired, delete it asynchronously + go func() { + deleteCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := plugin.store.Delete(deleteCtx, plugin.config.VectorStoreNamespace, result.ID) + if err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to delete expired entry %s: %v", PluginLoggerPrefix, result.ID, err)) + } + }() + // Return nil to indicate cache miss + return nil, nil + } + } + } + + // Check if this is a streaming response - need to check for non-null values + streamResponses, hasStreamingResponse := properties["stream_chunks"] + singleResponse, hasSingleResponse := properties["response"] + + // Consider fields present only if they're not null + hasValidSingleResponse := hasSingleResponse && singleResponse != nil + hasValidStreamingResponse := hasStreamingResponse && streamResponses != nil + + // Parse stream_chunks + streamChunks, err := plugin.parseStreamChunks(streamResponses) + if err != nil || len(streamChunks) == 0 { + hasValidStreamingResponse = false + } + + similarity := 0.0 + if result.Score != nil { + similarity = *result.Score + } + + if hasValidStreamingResponse && !hasValidSingleResponse { + // Handle streaming response + return plugin.buildStreamingResponseFromResult(ctx, req, result, streamResponses, cacheType, threshold, similarity, inputTokens) + } else if hasValidSingleResponse && !hasValidStreamingResponse { + // Handle single response + return plugin.buildSingleResponseFromResult(ctx, req, result, singleResponse, cacheType, threshold, similarity, inputTokens) + } else { + return nil, fmt.Errorf("cached result has invalid response data: both or neither response/stream_chunks are present (response: %v, stream_chunks: %v)", singleResponse, streamResponses) + } +} + +// buildSingleResponseFromResult constructs a single response from cached data +func (plugin *Plugin) buildSingleResponseFromResult(ctx *context.Context, req *schemas.BifrostRequest, result vectorstore.SearchResult, responseData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.PluginShortCircuit, error) { + responseStr, ok := responseData.(string) + if !ok { + return nil, fmt.Errorf("cached response is not a string") + } + + // Unmarshal the cached response + var cachedResponse schemas.BifrostResponse + if err := json.Unmarshal([]byte(responseStr), &cachedResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal cached response: %w", err) + } + + if cachedResponse.ExtraFields.CacheDebug == nil { + cachedResponse.ExtraFields.CacheDebug = &schemas.BifrostCacheDebug{} + } + cachedResponse.ExtraFields.CacheDebug.CacheHit = true + cachedResponse.ExtraFields.CacheDebug.HitType = bifrost.Ptr(string(cacheType)) + cachedResponse.ExtraFields.CacheDebug.CacheID = bifrost.Ptr(result.ID) + if cacheType == CacheTypeSemantic { + cachedResponse.ExtraFields.CacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider)) + cachedResponse.ExtraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel) + cachedResponse.ExtraFields.CacheDebug.Threshold = &threshold + cachedResponse.ExtraFields.CacheDebug.Similarity = &similarity + cachedResponse.ExtraFields.CacheDebug.InputTokens = &inputTokens + } else { + cachedResponse.ExtraFields.CacheDebug.ProviderUsed = nil + cachedResponse.ExtraFields.CacheDebug.ModelUsed = nil + cachedResponse.ExtraFields.CacheDebug.Threshold = nil + cachedResponse.ExtraFields.CacheDebug.Similarity = nil + cachedResponse.ExtraFields.CacheDebug.InputTokens = nil + } + + cachedResponse.ExtraFields.Provider = req.Provider + + *ctx = context.WithValue(*ctx, isCacheHitKey, true) + *ctx = context.WithValue(*ctx, cacheHitTypeKey, cacheType) + + return &schemas.PluginShortCircuit{ + Response: &cachedResponse, + }, nil +} + +// buildStreamingResponseFromResult constructs a streaming response from cached data +func (plugin *Plugin) buildStreamingResponseFromResult(ctx *context.Context, req *schemas.BifrostRequest, result vectorstore.SearchResult, streamData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.PluginShortCircuit, error) { + // Parse stream_chunks + streamArray, err := plugin.parseStreamChunks(streamData) + if err != nil { + return nil, fmt.Errorf("failed to parse stream_chunks: %w", err) + } + + // Mark cache-hit once to avoid concurrent ctx writes + *ctx = context.WithValue(*ctx, isCacheHitKey, true) + *ctx = context.WithValue(*ctx, cacheHitTypeKey, cacheType) + + // Create stream channel + streamChan := make(chan *schemas.BifrostStream) + + go func() { + defer close(streamChan) + + // Set cache-hit markers inside the streaming goroutine to avoid races + *ctx = context.WithValue(*ctx, isCacheHitKey, true) + *ctx = context.WithValue(*ctx, cacheHitTypeKey, cacheType) + + // Process each stream chunk + for i, chunkData := range streamArray { + chunkStr, ok := chunkData.(string) + if !ok { + plugin.logger.Warn(fmt.Sprintf("%s Stream chunk %d is not a string, skipping", PluginLoggerPrefix, i)) + continue + } + + // Unmarshal the chunk as BifrostResponse + var cachedResponse schemas.BifrostResponse + if err := json.Unmarshal([]byte(chunkStr), &cachedResponse); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to unmarshal stream chunk %d, skipping: %v", PluginLoggerPrefix, i, err)) + continue + } + + // Add cache debug to only the last chunk and set stream end indicator + if i == len(streamArray)-1 { + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + + if cachedResponse.ExtraFields.CacheDebug == nil { + cachedResponse.ExtraFields.CacheDebug = &schemas.BifrostCacheDebug{} + } + cachedResponse.ExtraFields.CacheDebug.CacheHit = true + cachedResponse.ExtraFields.CacheDebug.HitType = bifrost.Ptr(string(cacheType)) + cachedResponse.ExtraFields.CacheDebug.CacheID = bifrost.Ptr(result.ID) + if cacheType == CacheTypeSemantic { + cachedResponse.ExtraFields.CacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider)) + cachedResponse.ExtraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel) + cachedResponse.ExtraFields.CacheDebug.Threshold = &threshold + cachedResponse.ExtraFields.CacheDebug.Similarity = &similarity + cachedResponse.ExtraFields.CacheDebug.InputTokens = &inputTokens + } else { + cachedResponse.ExtraFields.CacheDebug.ProviderUsed = nil + cachedResponse.ExtraFields.CacheDebug.ModelUsed = nil + cachedResponse.ExtraFields.CacheDebug.Threshold = nil + cachedResponse.ExtraFields.CacheDebug.Similarity = nil + cachedResponse.ExtraFields.CacheDebug.InputTokens = nil + } + } + + cachedResponse.ExtraFields.Provider = req.Provider + + // Send chunk to stream + streamChan <- &schemas.BifrostStream{ + BifrostResponse: &cachedResponse, + } + } + }() + + return &schemas.PluginShortCircuit{ + Stream: streamChan, + }, nil +} + +// parseStreamChunks parses stream_chunks data from various formats into []interface{} +// Handles []interface{}, []string, and JSON string formats +func (plugin *Plugin) parseStreamChunks(streamData interface{}) ([]interface{}, error) { + if streamData == nil { + return nil, fmt.Errorf("stream data is nil") + } + + switch v := streamData.(type) { + case []interface{}: + return v, nil + case []string: + // Convert []string to []interface{} + result := make([]interface{}, len(v)) + for i, s := range v { + result[i] = s + } + return result, nil + case string: + // Parse JSON string from Redis + var stringArray []string + if err := json.Unmarshal([]byte(v), &stringArray); err != nil { + return nil, fmt.Errorf("failed to parse JSON string: %w", err) + } + // Convert to []interface{} + result := make([]interface{}, len(stringArray)) + for i, s := range stringArray { + result[i] = s + } + return result, nil + default: + return nil, fmt.Errorf("unsupported stream data type: %T", streamData) + } +} diff --git a/plugins/semanticcache/stream.go b/plugins/semanticcache/stream.go new file mode 100644 index 000000000..f13b0e27f --- /dev/null +++ b/plugins/semanticcache/stream.go @@ -0,0 +1,168 @@ +package semanticcache + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "time" +) + +// Streaming State Management Methods + +// createStreamAccumulator creates a new stream accumulator for a request +func (plugin *Plugin) createStreamAccumulator(requestID string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) *StreamAccumulator { + accumulator := &StreamAccumulator{ + RequestID: requestID, + Chunks: make([]*StreamChunk, 0), + IsComplete: false, + Embedding: embedding, + Metadata: metadata, + TTL: ttl, + } + + plugin.streamAccumulators.Store(requestID, accumulator) + return accumulator +} + +// getOrCreateStreamAccumulator gets or creates a stream accumulator for a request +func (plugin *Plugin) getOrCreateStreamAccumulator(requestID string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) *StreamAccumulator { + if accumulator, exists := plugin.streamAccumulators.Load(requestID); exists { + return accumulator.(*StreamAccumulator) + } + + // Create new accumulator if it doesn't exist + return plugin.createStreamAccumulator(requestID, embedding, metadata, ttl) +} + +// addStreamChunk adds a chunk to the stream accumulator +func (plugin *Plugin) addStreamChunk(requestID string, chunk *StreamChunk, isFinalChunk bool) error { + // Get accumulator (should exist if properly initialized) + accumulatorInterface, exists := plugin.streamAccumulators.Load(requestID) + if !exists { + return fmt.Errorf("stream accumulator not found for request %s", requestID) + } + + accumulator := accumulatorInterface.(*StreamAccumulator) + accumulator.mu.Lock() + defer accumulator.mu.Unlock() + + // Add chunk to the list (chunks arrive in order) + accumulator.Chunks = append(accumulator.Chunks, chunk) + + // Set FinalTimestamp when FinishReason is present + // This handles both normal completion chunks and usage-only last chunks + if isFinalChunk { + accumulator.FinalTimestamp = chunk.Timestamp + } + + plugin.logger.Debug(fmt.Sprintf("%s Added chunk to stream accumulator for request %s", PluginLoggerPrefix, requestID)) + + return nil +} + +// processAccumulatedStream processes all accumulated chunks and caches the complete stream +// Flow: Collect everything β†’ Check for ANY errors β†’ If no errors, order and send to .Add() β†’ If any errors, drop operation +func (plugin *Plugin) processAccumulatedStream(ctx context.Context, requestID string) error { + accumulatorInterface, exists := plugin.streamAccumulators.Load(requestID) + if !exists { + return fmt.Errorf("stream accumulator not found for request %s", requestID) + } + + accumulator := accumulatorInterface.(*StreamAccumulator) + accumulator.mu.Lock() + + // Ensure cleanup happens + defer plugin.cleanupStreamAccumulator(requestID) + defer accumulator.mu.Unlock() + + // STEP 1: Check if any chunk in the entire stream had an error + if accumulator.HasError { + plugin.logger.Debug(fmt.Sprintf("%s Stream for request %s had errors, dropping entire operation (not caching)", PluginLoggerPrefix, requestID)) + return nil + } + + // STEP 2: All chunks are clean, now sort and build ordered stream for caching + plugin.logger.Debug(fmt.Sprintf("%s Stream for request %s completed successfully, processing %d chunks for caching", PluginLoggerPrefix, requestID, len(accumulator.Chunks))) + + // Sort chunks by their ChunkIndex to ensure proper order (stable + nil-safe) + sort.SliceStable(accumulator.Chunks, func(i, j int) bool { + if accumulator.Chunks[i].Response == nil || accumulator.Chunks[j].Response == nil { + // Push nils to the end deterministically + return accumulator.Chunks[j].Response != nil + } + return accumulator.Chunks[i].Response.ExtraFields.ChunkIndex < accumulator.Chunks[j].Response.ExtraFields.ChunkIndex + }) + + var streamResponses []string + for i, chunk := range accumulator.Chunks { + if chunk.Response != nil { + chunkData, err := json.Marshal(chunk.Response) + if err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to marshal stream chunk %d: %v", PluginLoggerPrefix, i, err)) + continue + } + streamResponses = append(streamResponses, string(chunkData)) + } + } + + // STEP 3: Validate we have valid chunks to cache + if len(streamResponses) == 0 { + plugin.logger.Warn(fmt.Sprintf("%s Stream for request %s has no valid response chunks, skipping cache storage", PluginLoggerPrefix, requestID)) + return nil + } + + // STEP 4: Build final metadata and submit to .Add() method + finalMetadata := make(map[string]interface{}) + for k, v := range accumulator.Metadata { + finalMetadata[k] = v + } + finalMetadata["stream_chunks"] = streamResponses + + // Store complete unified entry using original requestID - this is the final .Add() call + if err := plugin.store.Add(ctx, plugin.config.VectorStoreNamespace, requestID, accumulator.Embedding, finalMetadata); err != nil { + return fmt.Errorf("failed to store complete streaming cache entry: %w", err) + } + + plugin.logger.Debug(fmt.Sprintf("%s Successfully cached complete stream with %d ordered chunks, ID: %s", PluginLoggerPrefix, len(streamResponses), requestID)) + return nil +} + +// cleanupStreamAccumulator removes the stream accumulator for a request +func (plugin *Plugin) cleanupStreamAccumulator(requestID string) { + plugin.streamAccumulators.Delete(requestID) +} + +// cleanupOldStreamAccumulators removes stream accumulators older than 5 minutes +func (plugin *Plugin) cleanupOldStreamAccumulators() { + fiveMinutesAgo := time.Now().Add(-5 * time.Minute) + cleanedCount := 0 + toDelete := make([]string, 0) + + plugin.streamAccumulators.Range(func(key, value interface{}) bool { + requestID := key.(string) + accumulator := value.(*StreamAccumulator) + + // Check if this accumulator is old (no activity for 5 minutes) + accumulator.mu.Lock() + if len(accumulator.Chunks) > 0 { + firstChunkTime := accumulator.Chunks[0].Timestamp + if firstChunkTime.Before(fiveMinutesAgo) { + toDelete = append(toDelete, requestID) + plugin.logger.Debug(fmt.Sprintf("%s Cleaned up old stream accumulator for request %s", PluginLoggerPrefix, requestID)) + } + } + accumulator.mu.Unlock() + return true + }) + + // Delete outside the Range loop to avoid concurrent modification + for _, requestID := range toDelete { + plugin.streamAccumulators.Delete(requestID) + cleanedCount++ + } + + if cleanedCount > 0 { + plugin.logger.Debug(fmt.Sprintf("%s Cleaned up %d old stream accumulators", PluginLoggerPrefix, cleanedCount)) + } +} diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go new file mode 100644 index 000000000..22c4b834a --- /dev/null +++ b/plugins/semanticcache/test_utils.go @@ -0,0 +1,463 @@ +package semanticcache + +import ( + "context" + "os" + "strconv" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/vectorstore" +) + +// getWeaviateConfigFromEnv retrieves Weaviate configuration from environment variables +func getWeaviateConfigFromEnv() vectorstore.WeaviateConfig { + scheme := os.Getenv("WEAVIATE_SCHEME") + if scheme == "" { + scheme = "http" + } + + host := os.Getenv("WEAVIATE_HOST") + if host == "" { + host = "localhost:9000" + } + + apiKey := os.Getenv("WEAVIATE_API_KEY") + + timeoutStr := os.Getenv("WEAVIATE_TIMEOUT") + timeout := 30 // default + if timeoutStr != "" { + if t, err := strconv.Atoi(timeoutStr); err == nil { + timeout = t + } + } + + return vectorstore.WeaviateConfig{ + Scheme: scheme, + Host: host, + ApiKey: apiKey, + Timeout: time.Duration(timeout) * time.Second, + } +} + +// getRedisConfigFromEnv retrieves Redis configuration from environment variables +func getRedisConfigFromEnv() vectorstore.RedisConfig { + addr := os.Getenv("REDIS_ADDR") + if addr == "" { + addr = "localhost:6379" + } + username := os.Getenv("REDIS_USERNAME") + password := os.Getenv("REDIS_PASSWORD") + db := os.Getenv("REDIS_DB") + if db == "" { + db = "0" + } + dbInt, err := strconv.Atoi(db) + if err != nil { + dbInt = 0 + } + + timeoutStr := os.Getenv("REDIS_TIMEOUT") + if timeoutStr == "" { + timeoutStr = "10s" + } + timeout, err := time.ParseDuration(timeoutStr) + if err != nil { + timeout = 10 * time.Second + } + + return vectorstore.RedisConfig{ + Addr: addr, + Username: username, + Password: password, + DB: dbInt, + ContextTimeout: timeout, + } +} + +// BaseAccount implements the schemas.Account interface for testing purposes. +type BaseAccount struct{} + +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, // Empty models array means it supports ALL models + Weight: 1.0, + }, + }, nil +} + +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 5, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 10 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestSetup contains common test setup components +type TestSetup struct { + Logger schemas.Logger + Store vectorstore.VectorStore + Plugin schemas.Plugin + Client *bifrost.Bifrost + Config Config +} + +// NewTestSetup creates a new test setup with default configuration +func NewTestSetup(t *testing.T) *TestSetup { + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping test") + } + + return NewTestSetupWithConfig(t, Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.8, + CleanUpOnShutdown: true, + Keys: []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, + }) +} + +// NewTestSetupWithConfig creates a new test setup with custom configuration +func NewTestSetupWithConfig(t *testing.T, config Config) *TestSetup { + ctx := context.Background() + logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + + store, err := vectorstore.NewVectorStore(context.Background(), &vectorstore.Config{ + Type: vectorstore.VectorStoreTypeWeaviate, + Config: getWeaviateConfigFromEnv(), + Enabled: true, + }, logger) + if err != nil { + t.Fatalf("Vector store not available or failed to connect: %v", err) + } + + plugin, err := Init(context.Background(), config, logger, store) + if err != nil { + t.Fatalf("Failed to initialize plugin: %v", err) + } + + // Clear test keys + pluginImpl := plugin.(*Plugin) + clearTestKeysWithStore(t, pluginImpl.store) + + account := &BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + Plugins: []schemas.Plugin{plugin}, + Logger: logger, + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + + return &TestSetup{ + Logger: logger, + Store: store, + Plugin: plugin, + Client: client, + Config: config, + } +} + +// Cleanup cleans up test resources +func (ts *TestSetup) Cleanup() { + if ts.Client != nil { + ts.Client.Shutdown() + } +} + +// clearTestKeysWithStore removes all keys matching the test prefix using the store interface +func clearTestKeysWithStore(t *testing.T, store vectorstore.VectorStore) { + // With the new unified VectorStore interface, cleanup is typically handled + // by the vector store implementation (e.g., dropping entire classes) + t.Logf("Test cleanup delegated to vector store implementation") +} + +// CreateBasicChatRequest creates a basic chat completion request for testing +func CreateBasicChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: &content, + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: &temperature, + MaxTokens: &maxTokens, + }, + } +} + +// CreateStreamingChatRequest creates a streaming chat completion request for testing +func CreateStreamingChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostRequest { + return CreateBasicChatRequest(content, temperature, maxTokens) +} + +// CreateSpeechRequest creates a speech synthesis request for testing +func CreateSpeechRequest(input string, voice string) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "tts-1", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: input, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + }, + }, + } +} + +// AssertCacheHit verifies that a response was served from cache +func AssertCacheHit(t *testing.T, response *schemas.BifrostResponse, expectedCacheType string) { + if response.ExtraFields.CacheDebug == nil { + t.Error("Cache metadata missing 'cache_debug'") + return + } + + // Check that it's actually a cache hit + if !response.ExtraFields.CacheDebug.CacheHit { + t.Error("❌ Expected cache hit but response was not cached") + return + } + + if expectedCacheType != "" { + cacheType := response.ExtraFields.CacheDebug.HitType + if cacheType != nil && *cacheType != expectedCacheType { + t.Errorf("Expected cache type '%s', got '%s'", expectedCacheType, *cacheType) + return + } + + t.Log("βœ… Response correctly served from cache") + } + + t.Log("βœ… Response correctly served from cache") +} + +// AssertNoCacheHit verifies that a response was NOT served from cache +func AssertNoCacheHit(t *testing.T, response *schemas.BifrostResponse) { + if response.ExtraFields.CacheDebug == nil { + t.Log("βœ… Response correctly not served from cache (no 'cache_debug' flag)") + return + } + + // Check the actual CacheHit field instead of just checking if CacheDebug exists + if response.ExtraFields.CacheDebug.CacheHit { + t.Error("❌ Response was cached when it shouldn't be") + return + } + + t.Log("βœ… Response correctly not served from cache (cache_debug present but CacheHit=false)") +} + +// WaitForCache waits for async cache operations to complete +func WaitForCache() { + time.Sleep(1 * time.Second) +} + +// CreateEmbeddingRequest creates an embedding request for testing +func CreateEmbeddingRequest(texts []string) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "text-embedding-3-small", + Input: schemas.RequestInput{ + EmbeddingInput: &schemas.EmbeddingInput{ + Texts: texts, + }, + }, + } +} + +// CreateContextWithCacheKey creates a context with the test cache key +func CreateContextWithCacheKey(value string) context.Context { + return context.WithValue(context.Background(), CacheKey, value) +} + +// CreateContextWithCacheKeyAndType creates a context with cache key and cache type +func CreateContextWithCacheKeyAndType(value string, cacheType CacheType) context.Context { + ctx := context.WithValue(context.Background(), CacheKey, value) + return context.WithValue(ctx, CacheTypeKey, cacheType) +} + +// CreateContextWithCacheKeyAndTTL creates a context with cache key and custom TTL +func CreateContextWithCacheKeyAndTTL(value string, ttl time.Duration) context.Context { + ctx := context.WithValue(context.Background(), CacheKey, value) + return context.WithValue(ctx, CacheTTLKey, ttl) +} + +// CreateContextWithCacheKeyAndThreshold creates a context with cache key and custom threshold +func CreateContextWithCacheKeyAndThreshold(value string, threshold float64) context.Context { + ctx := context.WithValue(context.Background(), CacheKey, value) + return context.WithValue(ctx, CacheThresholdKey, threshold) +} + +// CreateContextWithCacheKeyAndNoStore creates a context with cache key and no-store flag +func CreateContextWithCacheKeyAndNoStore(value string, noStore bool) context.Context { + ctx := context.WithValue(context.Background(), CacheKey, value) + return context.WithValue(ctx, CacheNoStoreKey, noStore) +} + +// CreateTestSetupWithConversationThreshold creates a test setup with custom conversation history threshold +func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *TestSetup { + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping test") + } + + config := Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + CleanUpOnShutdown: true, + Threshold: 0.8, + ConversationHistoryThreshold: threshold, + Keys: []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, + } + + return NewTestSetupWithConfig(t, config) +} + +// CreateTestSetupWithExcludeSystemPrompt creates a test setup with ExcludeSystemPrompt setting +func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *TestSetup { + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping test") + } + + config := Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + CleanUpOnShutdown: true, + Threshold: 0.8, + ExcludeSystemPrompt: &excludeSystem, + Keys: []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, + } + + return NewTestSetupWithConfig(t, config) +} + +// CreateTestSetupWithThresholdAndExcludeSystem creates a test setup with both conversation threshold and exclude system prompt settings +func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, excludeSystem bool) *TestSetup { + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("OPENAI_API_KEY is not set, skipping test") + } + + config := Config{ + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + CleanUpOnShutdown: true, + Threshold: 0.8, + ConversationHistoryThreshold: threshold, + ExcludeSystemPrompt: &excludeSystem, + Keys: []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, + } + + return NewTestSetupWithConfig(t, config) +} + +// CreateConversationRequest creates a chat request with conversation history +func CreateConversationRequest(messages []schemas.BifrostMessage, temperature float64, maxTokens int) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: &schemas.ModelParameters{ + Temperature: &temperature, + MaxTokens: &maxTokens, + }, + } +} + +// BuildConversationHistory creates a conversation history from pairs of user/assistant messages +func BuildConversationHistory(systemPrompt string, userAssistantPairs ...[]string) []schemas.BifrostMessage { + messages := []schemas.BifrostMessage{} + + // Add system prompt if provided + if systemPrompt != "" { + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: &systemPrompt, + }, + }) + } + + // Add user/assistant pairs + for _, pair := range userAssistantPairs { + if len(pair) >= 1 && pair[0] != "" { + userMsg := pair[0] + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &userMsg, + }, + }) + } + if len(pair) >= 2 && pair[1] != "" { + assistantMsg := pair[1] + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &assistantMsg, + }, + }) + } + } + + return messages +} + +// AddUserMessage adds a user message to existing conversation +func AddUserMessage(messages []schemas.BifrostMessage, userMessage string) []schemas.BifrostMessage { + newMessage := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &userMessage, + }, + } + return append(messages, newMessage) +} diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go new file mode 100644 index 000000000..bcdaf99c9 --- /dev/null +++ b/plugins/semanticcache/utils.go @@ -0,0 +1,461 @@ +package semanticcache + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "strings" + "time" + + "github.com/cespare/xxhash/v2" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// normalizeText applies consistent normalization to text inputs for better cache hit rates. +// It converts text to lowercase and trims whitespace to reduce cache misses due to minor variations. +func normalizeText(text string) string { + return strings.ToLower(strings.TrimSpace(text)) +} + +// generateEmbedding generates an embedding for the given text using the configured provider. +func (plugin *Plugin) generateEmbedding(ctx context.Context, text string) ([]float32, int, error) { + // Create embedding request + embeddingReq := &schemas.BifrostRequest{ + Provider: plugin.config.Provider, + Model: plugin.config.EmbeddingModel, + Input: schemas.RequestInput{ + EmbeddingInput: &schemas.EmbeddingInput{ + Texts: []string{text}, + }, + }, + } + + // Generate embedding using bifrost client + response, err := plugin.client.EmbeddingRequest(ctx, embeddingReq) + if err != nil { + return nil, 0, fmt.Errorf("failed to generate embedding: %v", err) + } + + // Extract the first embedding from response + if len(response.Data) == 0 { + return nil, 0, fmt.Errorf("no embeddings returned from provider") + } + + // Get the embedding from the first data item + embedding := response.Data[0].Embedding + inputTokens := 0 + if response.Usage != nil { + inputTokens = response.Usage.TotalTokens + } + + if embedding.EmbeddingStr != nil { + // decode embedding.EmbeddingStr to []float32 + var vals []float32 + if err := json.Unmarshal([]byte(*embedding.EmbeddingStr), &vals); err != nil { + return nil, 0, fmt.Errorf("failed to parse string embedding: %w", err) + } + return vals, inputTokens, nil + } else if embedding.EmbeddingArray != nil { + return *embedding.EmbeddingArray, inputTokens, nil + } else if embedding.Embedding2DArray != nil && len(*embedding.Embedding2DArray) > 0 { + // Flatten 2D array into single embedding + var flattened []float32 + for _, arr := range *embedding.Embedding2DArray { + flattened = append(flattened, arr...) + } + return flattened, inputTokens, nil + } + + return nil, 0, fmt.Errorf("embedding data is not in expected format") +} + +// generateRequestHash creates an xxhash of the request for semantic cache key generation. +// It normalizes the request by including all relevant fields that affect the response: +// - Input (chat completion, text completion, etc.) +// - Parameters (temperature, max_tokens, tools, etc.) +// - Provider (if CacheByProvider is true) +// - Model (if CacheByModel is true) +// +// Note: Fallbacks are excluded as they only affect error handling, not the actual response. +// +// Parameters: +// - req: The Bifrost request to hash for semantic cache key generation +// +// Returns: +// - string: Hexadecimal representation of the xxhash +// - error: Any error that occurred during request normalization or hashing +func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest, requestType schemas.RequestType) (string, error) { + // Create a hash input structure that includes both input and parameters + hashInput := struct { + Input schemas.RequestInput `json:"input"` + Params *schemas.ModelParameters `json:"params,omitempty"` + Stream bool `json:"stream,omitempty"` + }{ + Input: *plugin.getInputForCaching(req), + Params: req.Params, + Stream: plugin.isStreamingRequest(requestType), + } + + // Marshal to JSON for consistent hashing + jsonData, err := json.Marshal(hashInput) + if err != nil { + return "", fmt.Errorf("failed to marshal request for hashing: %w", err) + } + + // Generate hash based on configured algorithm + hash := xxhash.Sum64(jsonData) + return fmt.Sprintf("%x", hash), nil +} + +// extractTextForEmbedding extracts meaningful text from different input types for embedding generation. +// Returns the text to embed and metadata for storage. +func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest, requestType schemas.RequestType) (string, string, error) { + metadata := map[string]interface{}{} + + attachments := []string{} + + // Add parameters as metadata if present + if req.Params != nil { + if req.Params.ToolChoice != nil { + if req.Params.ToolChoice.ToolChoiceStr != nil { + metadata["tool_choice"] = *req.Params.ToolChoice.ToolChoiceStr + } else if req.Params.ToolChoice.ToolChoiceStruct != nil { + metadata["tool_choice"] = (*req.Params.ToolChoice.ToolChoiceStruct).Function.Name + } + } + if req.Params.Temperature != nil { + metadata["temperature"] = *req.Params.Temperature + } + if req.Params.TopP != nil { + metadata["top_p"] = *req.Params.TopP + } + if req.Params.TopK != nil { + metadata["top_k"] = *req.Params.TopK + } + if req.Params.MaxTokens != nil { + metadata["max_tokens"] = *req.Params.MaxTokens + } + if req.Params.StopSequences != nil { + metadata["stop_sequences"] = *req.Params.StopSequences + } + if req.Params.PresencePenalty != nil { + metadata["presence_penalty"] = *req.Params.PresencePenalty + } + if req.Params.FrequencyPenalty != nil { + metadata["frequency_penalty"] = *req.Params.FrequencyPenalty + } + if req.Params.ParallelToolCalls != nil { + metadata["parallel_tool_calls"] = *req.Params.ParallelToolCalls + } + if req.Params.User != nil { + metadata["user"] = *req.Params.User + } + + if len(req.Params.ExtraParams) > 0 { + maps.Copy(metadata, req.Params.ExtraParams) + } + } + + metadata["stream"] = plugin.isStreamingRequest(requestType) + + if req.Params != nil && req.Params.Tools != nil { + if toolsJSON, err := json.Marshal(*req.Params.Tools); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err)) + } else { + toolHash := xxhash.Sum64(toolsJSON) + metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) + } + } + + switch { + case req.Input.TextCompletionInput != nil: + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + + return *req.Input.TextCompletionInput, metadataHash, nil + + case req.Input.ChatCompletionInput != nil: + reqInput := plugin.getInputForCaching(req) + + // Serialize chat messages for embedding + var textParts []string + for _, msg := range *reqInput.ChatCompletionInput { + // Extract content as string + var content string + if msg.Content.ContentStr != nil { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + // For content blocks, extract text parts + var blockTexts []string + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + blockTexts = append(blockTexts, *block.Text) + } + if block.ImageURL != nil && block.ImageURL.URL != "" { + attachments = append(attachments, block.ImageURL.URL) + } + } + content = strings.Join(blockTexts, " ") + } + + if content != "" { + textParts = append(textParts, fmt.Sprintf("%s: %s", msg.Role, content)) + } + } + + if len(textParts) == 0 { + return "", "", fmt.Errorf("no text content found in chat messages") + } + + if len(attachments) > 0 { + metadata["attachments"] = attachments + } + + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + + return strings.Join(textParts, "\n"), metadataHash, nil + + case req.Input.SpeechInput != nil: + if req.Input.SpeechInput.Input != "" { + if req.Input.SpeechInput.VoiceConfig.Voice != nil { + metadata["voice"] = *req.Input.SpeechInput.VoiceConfig.Voice + } + + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + + return req.Input.SpeechInput.Input, metadataHash, nil + } + return "", "", fmt.Errorf("no input text found in speech request") + + case req.Input.EmbeddingInput != nil: + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + + texts := req.Input.EmbeddingInput.Texts + + if len(texts) == 0 && req.Input.EmbeddingInput.Text != nil { + texts = []string{*req.Input.EmbeddingInput.Text} + } + + var text string + for _, t := range texts { + text += t + " " + } + + return strings.TrimSpace(text), metadataHash, nil + + case req.Input.TranscriptionInput != nil: + // Skip semantic caching for transcription requests + return "", "", fmt.Errorf("transcription requests are not supported for semantic caching") + + default: + return "", "", fmt.Errorf("unsupported input type for semantic caching") + } +} + +func getMetadataHash(metadata map[string]interface{}) (string, error) { + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + return fmt.Sprintf("%x", xxhash.Sum64(metadataJSON)), nil +} + +// isStreamingRequest checks if the request is a streaming request +func (plugin *Plugin) isStreamingRequest(requestType schemas.RequestType) bool { + return requestType == schemas.ChatCompletionStreamRequest || + requestType == schemas.SpeechStreamRequest || + requestType == schemas.TranscriptionStreamRequest +} + +// buildUnifiedMetadata constructs the unified metadata structure for VectorEntry +func (plugin *Plugin) buildUnifiedMetadata(provider schemas.ModelProvider, model string, paramsHash string, requestHash string, cacheKey string, ttl time.Duration) map[string]interface{} { + unifiedMetadata := make(map[string]interface{}) + + // Top-level fields (outside params) + unifiedMetadata["provider"] = string(provider) + unifiedMetadata["model"] = model + unifiedMetadata["request_hash"] = requestHash + unifiedMetadata["cache_key"] = cacheKey + unifiedMetadata["from_bifrost_semantic_cache_plugin"] = true + + // Calculate expiration timestamp (current time + TTL) + expiresAt := time.Now().Add(ttl).Unix() + unifiedMetadata["expires_at"] = expiresAt + + // Individual param fields will be stored as params_* by the vectorstore + // We pass the params map to the vectorstore, and it handles the individual field storage + if paramsHash != "" { + unifiedMetadata["params_hash"] = paramsHash + } + + return unifiedMetadata +} + +// addSingleResponse stores a single (non-streaming) response in unified VectorEntry format +func (plugin *Plugin) addSingleResponse(ctx context.Context, responseID string, res *schemas.BifrostResponse, embedding []float32, metadata map[string]interface{}, ttl time.Duration) error { + // Marshal response as string + responseData, err := json.Marshal(res) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + // Add response field to metadata + metadata["response"] = string(responseData) + metadata["stream_chunks"] = []string{} + + // Store unified entry using new VectorStore interface + if err := plugin.store.Add(ctx, plugin.config.VectorStoreNamespace, responseID, embedding, metadata); err != nil { + return fmt.Errorf("failed to store unified cache entry: %w", err) + } + + plugin.logger.Debug(fmt.Sprintf("%s Successfully cached single response with ID: %s", PluginLoggerPrefix, responseID)) + return nil +} + +// addStreamingResponse handles streaming response storage by accumulating chunks +func (plugin *Plugin) addStreamingResponse(ctx context.Context, responseID string, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, embedding []float32, metadata map[string]interface{}, ttl time.Duration, isFinalChunk bool) error { + // Create accumulator if it doesn't exist + accumulator := plugin.getOrCreateStreamAccumulator(responseID, embedding, metadata, ttl) + + // Create chunk from current response + chunk := &StreamChunk{ + Timestamp: time.Now(), + Response: res, + } + + // Check for finish reason or set error finish reason + if bifrostErr != nil { + // Error case - mark as final chunk with error + chunk.FinishReason = bifrost.Ptr("error") + } else if res != nil && len(res.Choices) > 0 { + choice := res.Choices[0] + if choice.BifrostStreamResponseChoice != nil { + chunk.FinishReason = choice.FinishReason + } + } + + // Add chunk to accumulator synchronously to maintain order + if err := plugin.addStreamChunk(responseID, chunk, isFinalChunk); err != nil { + return fmt.Errorf("failed to add stream chunk: %w", err) + } + + // Check if this is the final chunk and gate final processing to ensure single invocation + accumulator.mu.Lock() + // Check for completion: either FinishReason is present, there's an error, or token usage exists + alreadyComplete := accumulator.IsComplete + + // Track if any chunk has an error + if bifrostErr != nil { + accumulator.HasError = true + } + + if isFinalChunk && !alreadyComplete { + accumulator.IsComplete = true + accumulator.FinalTimestamp = chunk.Timestamp + } + accumulator.mu.Unlock() + + // If this is the final chunk and hasn't been processed yet, process accumulated chunks + // Note: processAccumulatedStream will check for errors and skip caching if any errors occurred + if isFinalChunk && !alreadyComplete { + if processErr := plugin.processAccumulatedStream(ctx, responseID); processErr != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to process accumulated stream for request %s: %v", PluginLoggerPrefix, responseID, processErr)) + } + } + + return nil +} + +// getInputForCaching returns a normalized and sanitized copy of req.Input for hashing/embedding. +// It applies text normalization (lowercase + trim) and optionally removes system messages. +func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) *schemas.RequestInput { + reqInput := req.Input + + // Handle text completion normalization + if reqInput.TextCompletionInput != nil { + normalizedText := normalizeText(*reqInput.TextCompletionInput) + reqInput.TextCompletionInput = &normalizedText + } + + // Handle chat completion normalization + if reqInput.ChatCompletionInput != nil { + originalMessages := *reqInput.ChatCompletionInput + normalizedMessages := make([]schemas.BifrostMessage, 0, len(originalMessages)) + + for _, msg := range originalMessages { + // Skip system messages if configured to exclude them + if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role == schemas.ModelChatMessageRoleSystem { + continue + } + + // Create a copy of the message with normalized content + normalizedMsg := msg + + // Normalize message content + if msg.Content.ContentStr != nil { + normalizedContent := normalizeText(*msg.Content.ContentStr) + normalizedMsg.Content.ContentStr = &normalizedContent + } else if msg.Content.ContentBlocks != nil { + // Create a copy of content blocks with normalized text + normalizedBlocks := make([]schemas.ContentBlock, len(*msg.Content.ContentBlocks)) + for i, block := range *msg.Content.ContentBlocks { + normalizedBlocks[i] = block + if block.Text != nil { + normalizedText := normalizeText(*block.Text) + normalizedBlocks[i].Text = &normalizedText + } + } + normalizedMsg.Content.ContentBlocks = &normalizedBlocks + } + + normalizedMessages = append(normalizedMessages, normalizedMsg) + } + + reqInput.ChatCompletionInput = &normalizedMessages + } + + if reqInput.SpeechInput != nil { + normalizedInput := normalizeText(reqInput.SpeechInput.Input) + reqInput.SpeechInput.Input = normalizedInput + } + + return &reqInput +} + +// removeField removes the first occurrence of target from the slice. +func removeField(arr []string, target string) []string { + for i, v := range arr { + if v == target { + // remove element at index i + return append(arr[:i], arr[i+1:]...) + } + } + return arr // unchanged if target not found +} + +// isConversationHistoryThresholdExceeded checks if the conversation history threshold is exceeded +func (plugin *Plugin) isConversationHistoryThresholdExceeded(req *schemas.BifrostRequest) bool { + switch { + case req.Input.ChatCompletionInput != nil: + input := plugin.getInputForCaching(req) + if len(*input.ChatCompletionInput) > plugin.config.ConversationHistoryThreshold { + return true + } + return false + default: + return false + } +} diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version new file mode 100644 index 000000000..591bdbcd6 --- /dev/null +++ b/plugins/semanticcache/version @@ -0,0 +1 @@ +1.2.18 diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md new file mode 100644 index 000000000..6dcfe4edd --- /dev/null +++ b/plugins/telemetry/changelog.md @@ -0,0 +1,4 @@ + + + +- Upgrades framework to 1.0.23 \ No newline at end of file diff --git a/plugins/telemetry/docker-compose.yml b/plugins/telemetry/docker-compose.yml new file mode 100644 index 000000000..26ebdad61 --- /dev/null +++ b/plugins/telemetry/docker-compose.yml @@ -0,0 +1,29 @@ +# Prometheus and Grafana for tracking bifrost-http service (for development and testing purposes only, don't use in production without proper setup) +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + ports: + - "9090:9090" # Expose Prometheus web UI + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml # Prometheus config file + restart: always + networks: + - bifrost_tracking_network + + grafana: + image: grafana/grafana:latest + container_name: grafana + ports: + - "3000:3000" # Expose Grafana web UI + depends_on: + - prometheus + environment: + GF_SECURITY_ADMIN_PASSWORD: "admin" # Default admin password for Grafana + restart: always + networks: + - bifrost_tracking_network + +networks: + bifrost_tracking_network: + driver: bridge diff --git a/plugins/telemetry/go.mod b/plugins/telemetry/go.mod new file mode 100644 index 000000000..9d5e912c6 --- /dev/null +++ b/plugins/telemetry/go.mod @@ -0,0 +1,94 @@ +module github.com/maximhq/bifrost/plugins/telemetry + +go 1.24 + +toolchain go1.24.3 + +require ( + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/bifrost/framework v1.0.23 + github.com/prometheus/client_golang v1.23.0 + github.com/valyala/fasthttp v1.65.0 +) + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.65.0 // indirect + github.com/prometheus/procfs v0.17.0 // indirect + github.com/redis/go-redis/v9 v9.12.1 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/weaviate/weaviate v1.31.5 // indirect + github.com/weaviate/weaviate-go-client/v5 v5.2.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect + gorm.io/gorm v1.30.1 // indirect +) diff --git a/plugins/telemetry/go.sum b/plugins/telemetry/go.sum new file mode 100644 index 000000000..cb5c3a6e4 --- /dev/null +++ b/plugins/telemetry/go.sum @@ -0,0 +1,369 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/bifrost/framework v1.0.23 h1:erRPP9Q0WIaUgxuLBN8urd77SObEF9irPvpV9Wbegyk= +github.com/maximhq/bifrost/framework v1.0.23/go.mod h1:uEB0iuQtFfuFuMrhccMsb+51mf8m8X2tB8ZlDVoJUbM= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= +github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= +github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= +github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= +github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go new file mode 100644 index 000000000..ce99ceb8f --- /dev/null +++ b/plugins/telemetry/main.go @@ -0,0 +1,201 @@ +// Package telemetry provides Prometheus metrics collection and monitoring functionality +// for the Bifrost HTTP service. It includes middleware for HTTP request tracking +// and a plugin for tracking upstream provider metrics. +package telemetry + +import ( + "context" + "log" + "time" + + bifrost "github.com/maximhq/bifrost/core" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/pricing" + "github.com/prometheus/client_golang/prometheus" +) + +const ( + PluginName = "telemetry" +) + +// ContextKey is a custom type for prometheus context keys to prevent collisions +type ContextKey string + +const ( + startTimeKey ContextKey = "bf-prom-start-time" +) + +// PrometheusPlugin implements the schemas.Plugin interface for Prometheus metrics. +// It tracks metrics for upstream provider requests, including: +// - Total number of requests +// - Request latency +// - Error counts +type PrometheusPlugin struct { + pricingManager *pricing.PricingManager + + // Metrics are defined using promauto for automatic registration + UpstreamRequestsTotal *prometheus.CounterVec + UpstreamLatency *prometheus.HistogramVec + SuccessRequestsTotal *prometheus.CounterVec + ErrorRequestsTotal *prometheus.CounterVec + InputTokensTotal *prometheus.CounterVec + OutputTokensTotal *prometheus.CounterVec + CacheHitsTotal *prometheus.CounterVec + CostTotal *prometheus.CounterVec +} + +// NewPrometheusPlugin creates a new PrometheusPlugin with initialized metrics. +func Init(pricingManager *pricing.PricingManager, logger schemas.Logger) *PrometheusPlugin { + if pricingManager == nil { + logger.Warn("telemetry plugin requires pricing manager to calculate cost, all cost calculations will be skipped.") + } + + return &PrometheusPlugin{ + pricingManager: pricingManager, + UpstreamRequestsTotal: bifrostUpstreamRequestsTotal, + UpstreamLatency: bifrostUpstreamLatencySeconds, + SuccessRequestsTotal: bifrostSuccessRequestsTotal, + ErrorRequestsTotal: bifrostErrorRequestsTotal, + InputTokensTotal: bifrostInputTokensTotal, + OutputTokensTotal: bifrostOutputTokensTotal, + CacheHitsTotal: bifrostCacheHitsTotal, + CostTotal: bifrostCostTotal, + } +} + +// GetName returns the name of the plugin. +func (p *PrometheusPlugin) GetName() string { + return PluginName +} + +// PreHook records the start time of the request in the context. +// This time is used later in PostHook to calculate request duration. +func (p *PrometheusPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + *ctx = context.WithValue(*ctx, startTimeKey, time.Now()) + + return req, nil, nil +} + +// PostHook calculates duration and records upstream metrics for successful requests. +// It records: +// - Request latency +// - Total request count +func (p *PrometheusPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if result == nil { + return result, bifrostErr, nil + } + + requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + log.Println("Warning: request type not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + // For streaming requests, only record metrics on the final chunk + if bifrost.IsStreamRequestType(requestType) { + streamEndIndicatorValue := (*ctx).Value(schemas.BifrostContextKeyStreamEndIndicator) + if streamEndIndicatorValue == nil { + // No stream end indicator - this is an intermediate chunk, skip metrics + return result, bifrostErr, nil + } + + isFinalChunk, ok := streamEndIndicatorValue.(bool) + if !ok || !isFinalChunk { + // Not the final chunk or can't parse indicator - skip metrics + return result, bifrostErr, nil + } + // This is the final chunk - continue with metrics recording + } + + startTime, ok := (*ctx).Value(startTimeKey).(time.Time) + if !ok { + log.Println("Warning: startTime not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) + if !ok { + log.Println("Warning: provider not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) + if !ok { + log.Println("Warning: model not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + method, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + if !ok { + log.Println("Warning: method not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread + go func() { + cost := 0.0 + if p.pricingManager != nil { + cost = p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + } + + labelValues := map[string]string{ + "provider": string(provider), + "model": model, + "method": string(method), + } + + // Get all prometheus labels from context + for _, key := range customLabels { + if value := (*ctx).Value(ContextKey(key)); value != nil { + if strValue, ok := value.(string); ok { + labelValues[key] = strValue + } + } + } + + // Get label values in the correct order (cache_type will be handled separately for cache hits) + promLabelValues := getPrometheusLabelValues(append([]string{"provider", "model", "method"}, customLabels...), labelValues) + + duration := time.Since(startTime).Seconds() + p.UpstreamLatency.WithLabelValues(promLabelValues...).Observe(duration) + p.UpstreamRequestsTotal.WithLabelValues(promLabelValues...).Inc() + + // Record cost using the dedicated cost counter + if cost > 0 { + p.CostTotal.WithLabelValues(promLabelValues...).Add(cost) + } + + // Record error and success counts + if bifrostErr != nil { + p.ErrorRequestsTotal.WithLabelValues(promLabelValues...).Inc() + } else { + p.SuccessRequestsTotal.WithLabelValues(promLabelValues...).Inc() + } + + // Record input and output tokens + if result.Usage != nil { + p.InputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.PromptTokens)) + p.OutputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.CompletionTokens)) + } + + // Record cache hits with cache type + if result.ExtraFields.CacheDebug != nil && result.ExtraFields.CacheDebug.CacheHit { + cacheType := "unknown" + if result.ExtraFields.CacheDebug.HitType != nil { + cacheType = *result.ExtraFields.CacheDebug.HitType + } + + // Add cache_type to label values + cacheHitLabelValues := append(promLabelValues[:3], cacheType) // provider, model, method, cache_type + cacheHitLabelValues = append(cacheHitLabelValues, promLabelValues[3:]...) // then custom labels + + p.CacheHitsTotal.WithLabelValues(cacheHitLabelValues...).Inc() + } + }() + + return result, bifrostErr, nil +} + +func (p *PrometheusPlugin) Cleanup() error { + return nil +} diff --git a/plugins/telemetry/prometheus.yml b/plugins/telemetry/prometheus.yml new file mode 100644 index 000000000..6682b021f --- /dev/null +++ b/plugins/telemetry/prometheus.yml @@ -0,0 +1,15 @@ +# Prometheus configuration for tracking bifrost-http service (for development and testing purposes only, don't use in production without proper setup) +global: + scrape_interval: 5s # Scrape every 5 seconds + +# Note: Target configuration depends on your deployment environment: +# - For local development: Use "host.docker.internal:8080" to access the service running on your host machine +# - For Docker deployment: Use "bifrost-api:8080" to access the service within the Docker network +# Make sure to replace "bifrost-api" and "8080" with your actual docker container name and port if different +# Also check that you have the bifrost container inside "bifrost_tracking_network". + +scrape_configs: + - job_name: "bifrost-api" + static_configs: + - targets: ["host.docker.internal:8080"] # Scrape from the /metrics endpoint + diff --git a/plugins/telemetry/setup.go b/plugins/telemetry/setup.go new file mode 100644 index 000000000..70ae1ed4e --- /dev/null +++ b/plugins/telemetry/setup.go @@ -0,0 +1,276 @@ +// Package telemetry provides Prometheus metrics collection and monitoring functionality +// for the Bifrost HTTP service. This file contains the setup and configuration +// for Prometheus metrics collection, including HTTP middleware and metric definitions. +package telemetry + +import ( + "log" + "math" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/valyala/fasthttp" +) + +var ( + // httpRequestsTotal tracks the total number of HTTP requests + httpRequestsTotal *prometheus.CounterVec + // httpRequestDuration tracks the duration of HTTP requests + httpRequestDuration *prometheus.HistogramVec + // httpRequestSizeBytes tracks the size of incoming HTTP requests + httpRequestSizeBytes *prometheus.HistogramVec + // httpResponseSizeBytes tracks the size of outgoing HTTP responses + httpResponseSizeBytes *prometheus.HistogramVec + + // bifrostUpstreamRequestsTotal tracks the total number of requests forwarded to upstream providers by Bifrost. + bifrostUpstreamRequestsTotal *prometheus.CounterVec + + // bifrostUpstreamLatencySeconds tracks the latency of requests forwarded to upstream providers by Bifrost. + bifrostUpstreamLatencySeconds *prometheus.HistogramVec + + // bifrostSuccessRequestsTotal tracks the total number of successful requests forwarded to upstream providers by Bifrost. + bifrostSuccessRequestsTotal *prometheus.CounterVec + + // bifrostErrorRequestsTotal tracks the total number of error requests forwarded to upstream providers by Bifrost. + bifrostErrorRequestsTotal *prometheus.CounterVec + + // bifrostInputTokensTotal tracks the total number of input tokens forwarded to upstream providers by Bifrost. + bifrostInputTokensTotal *prometheus.CounterVec + + // bifrostOutputTokensTotal tracks the total number of output tokens forwarded to upstream providers by Bifrost. + bifrostOutputTokensTotal *prometheus.CounterVec + + // bifrostCacheHitsTotal tracks the total number of cache hits forwarded to upstream providers by Bifrost, separated by cache type (direct/semantic). + bifrostCacheHitsTotal *prometheus.CounterVec + + // bifrostCostTotal tracks the total cost in USD for requests to upstream providers + bifrostCostTotal *prometheus.CounterVec + + // customLabels stores the expected label names in order + customLabels []string + isInitialized bool +) + +func InitPrometheusMetrics(labels []string) { + if isInitialized { + return + } + + customLabels = labels + + httpDefaultLabels := []string{"path", "method", "status"} + bifrostDefaultLabels := []string{"provider", "model", "method"} + + // Upstream LLM latency buckets - extended range for AI model inference times + upstreamLatencyBuckets := []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 45, 60, 90} // in seconds + + httpRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Total number of HTTP requests.", + }, + append(httpDefaultLabels, labels...), + ) + + // httpRequestDuration tracks the duration of HTTP requests + httpRequestDuration = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Duration of HTTP requests.", + Buckets: prometheus.DefBuckets, + }, + append(httpDefaultLabels, labels...), + ) + + // httpRequestSizeBytes tracks the size of incoming HTTP requests + httpRequestSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_size_bytes", + Help: "Size of HTTP requests.", + Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB + }, + append(httpDefaultLabels, labels...), + ) + + // httpResponseSizeBytes tracks the size of outgoing HTTP responses + httpResponseSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_response_size_bytes", + Help: "Size of HTTP responses.", + Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB + }, + append(httpDefaultLabels, labels...), + ) + + // Bifrost Upstream Metrics (Defined globally, used by PrometheusPlugin) + bifrostUpstreamRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_upstream_requests_total", + Help: "Total number of requests forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostUpstreamLatencySeconds = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "bifrost_upstream_latency_seconds", + Help: "Latency of requests forwarded to upstream providers by Bifrost.", + Buckets: upstreamLatencyBuckets, // Extended range for AI model inference times + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostSuccessRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_success_requests_total", + Help: "Total number of successful requests forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostErrorRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_error_requests_total", + Help: "Total number of error requests forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostInputTokensTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_input_tokens_total", + Help: "Total number of input tokens forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostOutputTokensTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_output_tokens_total", + Help: "Total number of output tokens forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostCacheHitsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_cache_hits_total", + Help: "Total number of cache hits forwarded to upstream providers by Bifrost, separated by cache type (direct/semantic).", + }, + append(append(bifrostDefaultLabels, "cache_type"), labels...), + ) + + bifrostCostTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_cost_total", + Help: "Total cost in USD for requests to upstream providers.", + }, + append(bifrostDefaultLabels, labels...), + ) + + isInitialized = true +} + +// getPrometheusLabelValues takes an array of expected label keys and a map of header values, +// and returns an array of values in the same order as the keys, using empty string for missing values. +func getPrometheusLabelValues(expectedLabels []string, headerValues map[string]string) []string { + values := make([]string, len(expectedLabels)) + for i, label := range expectedLabels { + if value, exists := headerValues[label]; exists { + values[i] = value + } else { + values[i] = "" // Default empty value for missing labels + } + } + return values +} + +// collectPrometheusKeyValues collects all metrics for a request including: +// - Default metrics (path, method, status, request size) +// - Custom prometheus headers (x-bf-prom-*) +// Returns a map of all label values +func collectPrometheusKeyValues(ctx *fasthttp.RequestCtx) map[string]string { + path := string(ctx.Path()) + method := string(ctx.Method()) + + // Initialize with default metrics + labelValues := map[string]string{ + "path": path, + "method": method, + } + + // Collect custom prometheus headers + ctx.Request.Header.All()(func(key, value []byte) bool { + keyStr := strings.ToLower(string(key)) + if strings.HasPrefix(keyStr, "x-bf-prom-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-prom-") + labelValues[labelName] = string(value) + ctx.SetUserValue(keyStr, string(value)) + } + return true + }) + + return labelValues +} + +// PrometheusMiddleware wraps a FastHTTP handler to collect Prometheus metrics. +// It tracks: +// - Total number of requests +// - Request duration +// - Request and response sizes +// - HTTP status codes +// - Bifrost upstream requests and errors +func PrometheusMiddleware(handler fasthttp.RequestHandler) fasthttp.RequestHandler { + if !isInitialized { + log.Println("Prometheus metrics are not initialized. Please call InitPrometheusMetrics first. Skipping metrics collection.") + return handler + } + + return func(ctx *fasthttp.RequestCtx) { + start := time.Now() + + // Collect request metrics and headers + promKeyValues := collectPrometheusKeyValues(ctx) + reqSize := float64(ctx.Request.Header.ContentLength()) + + // Process the request + handler(ctx) + + // Record metrics after request completion + duration := time.Since(start).Seconds() + status := strconv.Itoa(ctx.Response.StatusCode()) + respSize := float64(ctx.Response.Header.ContentLength()) + + // Add status to the label values + promKeyValues["status"] = status + + // Get label values in the correct order + promLabelValues := getPrometheusLabelValues(append([]string{"path", "method", "status"}, customLabels...), promKeyValues) + + // Record all metrics with prometheus labels + httpRequestsTotal.WithLabelValues(promLabelValues...).Inc() + httpRequestDuration.WithLabelValues(promLabelValues...).Observe(duration) + if reqSize >= 0 { + safeObserve(httpRequestSizeBytes, reqSize, promLabelValues...) + } + if respSize >= 0 { + safeObserve(httpResponseSizeBytes, respSize, promLabelValues...) + } + } +} + +// safeObserve safely records a value in a Prometheus histogram. +// It prevents recording invalid values (negative or infinite) that could cause issues. +func safeObserve(histogram *prometheus.HistogramVec, value float64, labels ...string) { + if value > 0 && value < math.MaxFloat64 { + metric, err := histogram.GetMetricWithLabelValues(labels...) + if err != nil { + log.Printf("Error getting metric with label values: %v", err) + } else { + metric.Observe(value) + } + } +} diff --git a/plugins/telemetry/version b/plugins/telemetry/version new file mode 100644 index 000000000..05060b805 --- /dev/null +++ b/plugins/telemetry/version @@ -0,0 +1 @@ +1.2.15 \ No newline at end of file diff --git a/tests/configs/noconfigstorenologstore/config.json b/tests/configs/noconfigstorenologstore/config.json new file mode 100644 index 000000000..ad3d10774 --- /dev/null +++ b/tests/configs/noconfigstorenologstore/config.json @@ -0,0 +1,3 @@ +{ + "$schema": "https://www.getbifrost.ai/schema" +} \ No newline at end of file diff --git a/tests/configs/withconfigstore/config.json b/tests/configs/withconfigstore/config.json new file mode 100644 index 000000000..c0ab4f84c --- /dev/null +++ b/tests/configs/withconfigstore/config.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "../../tests/configs/withconfigstore/config.db" + } + } +} \ No newline at end of file diff --git a/tests/configs/withconfigstorelogsstore/config.json b/tests/configs/withconfigstorelogsstore/config.json new file mode 100644 index 000000000..56912ef4e --- /dev/null +++ b/tests/configs/withconfigstorelogsstore/config.json @@ -0,0 +1,17 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "../../tests/configs/withconfigstorelogsstore/config.db" + } + }, + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "../../tests/configs/withconfigstorelogsstore/logs.db" + } + } +} \ No newline at end of file diff --git a/tests/configs/withsemanticcache/config.json b/tests/configs/withsemanticcache/config.json new file mode 100644 index 000000000..f0775d490 --- /dev/null +++ b/tests/configs/withsemanticcache/config.json @@ -0,0 +1,21 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "localhost:9000" + } + }, + "plugins": [ + { + "enabled": true, + "name": "semantic_cache", + "config": { + "ttl": 300, + "threshold": 0.8 + } + } + ] +} \ No newline at end of file diff --git a/tests/core-chatbot/go.mod b/tests/core-chatbot/go.mod new file mode 100644 index 000000000..dddf8ceb2 --- /dev/null +++ b/tests/core-chatbot/go.mod @@ -0,0 +1,53 @@ +module github.com/maximhq/bifrost/tests/core-chatbot + +go 1.24 + +toolchain go1.24.3 + +replace github.com/maximhq/bifrost/core => ../../core + +require github.com/maximhq/bifrost/core v1.1.21 + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/core-chatbot/go.sum b/tests/core-chatbot/go.sum new file mode 100644 index 000000000..f35c0925a --- /dev/null +++ b/tests/core-chatbot/go.sum @@ -0,0 +1,122 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-chatbot/main.go b/tests/core-chatbot/main.go new file mode 100644 index 000000000..c98d410f5 --- /dev/null +++ b/tests/core-chatbot/main.go @@ -0,0 +1,944 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// ChatbotConfig holds configuration for the chatbot +type ChatbotConfig struct { + Provider schemas.ModelProvider + Model string + MCPAgenticMode bool + MCPServerPort int + Temperature *float64 + MaxTokens *int +} + +// ChatSession manages the conversation state +type ChatSession struct { + history []schemas.BifrostMessage + client *bifrost.Bifrost + config ChatbotConfig + systemPrompt string + account *ComprehensiveTestAccount +} + +// ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. +type ComprehensiveTestAccount struct{} + +// getEnvWithDefault returns the value of the environment variable if set, otherwise returns the default value +func getEnvWithDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +// GetConfiguredProviders returns the list of initially supported providers. +func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{ + schemas.OpenAI, + schemas.Anthropic, + schemas.Bedrock, + schemas.Cohere, + schemas.Azure, + schemas.Vertex, + schemas.Ollama, + schemas.Mistral, + }, nil +} + +// GetKeysForProvider returns the API keys and associated models for a given provider. +func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + switch providerKey { + case schemas.OpenAI: + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo", "gpt-4o"}, + Weight: 1.0, + }, + }, nil + case schemas.Anthropic: + return []schemas.Key{ + { + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, + Weight: 1.0, + }, + }, nil + case schemas.Bedrock: + return []schemas.Key{ + { + Value: os.Getenv("BEDROCK_API_KEY"), + Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, + Weight: 1.0, + }, + }, nil + case schemas.Cohere: + return []schemas.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025", "c4ai-aya-vision-8b"}, + Weight: 1.0, + }, + }, nil + case schemas.Azure: + return []schemas.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }, + }, nil + case schemas.Vertex: + return []schemas.Key{ + { + Value: os.Getenv("VERTEX_API_KEY"), + Models: []string{"gemini-pro", "gemini-1.5-pro"}, + Weight: 1.0, + }, + }, nil + case schemas.Mistral: + return []schemas.Key{ + { + Value: os.Getenv("MISTRAL_API_KEY"), + Models: []string{"mistral-large-2411", "pixtral-12b-latest"}, + Weight: 1.0, + }, + }, nil + case schemas.Ollama: + return []schemas.Key{ + { + Value: "", // Ollama is keyless + Models: []string{"llama3.2", "llama3.1", "mistral", "codellama"}, + Weight: 1.0, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// GetConfigForProvider returns the configuration settings for a given provider. +func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch providerKey { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Bedrock: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + // MetaConfig: &meta.BedrockMetaConfig{ // FIXME: meta package doesn't exist + // SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + // Region: bifrost.Ptr(getEnvWithDefault("AWS_REGION", "us-east-1")), + // }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Cohere: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Azure: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + // MetaConfig: &meta.AzureMetaConfig{ // FIXME: meta package doesn't exist + // Endpoint: os.Getenv("AZURE_ENDPOINT"), + // Deployments: map[string]string{ + // "gpt-4o": "gpt-4o-aug", + // }, + // APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + // }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Vertex: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + // MetaConfig: &meta.VertexMetaConfig{ // FIXME: meta package doesn't exist + // ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + // Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), + // AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + // }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Ollama: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Mistral: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// NewChatSession creates a new chat session with the given configuration +func NewChatSession(config ChatbotConfig) (*ChatSession, error) { + // Create MCP configuration for Bifrost + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + } + + fmt.Println("πŸ”Œ Configuring Serper MCP server...") + mcpConfig.ClientConfigs = append(mcpConfig.ClientConfigs, schemas.MCPClientConfig{ + Name: "serper-web-search-mcp", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"-y", "serper-search-scrape-mcp-server"}, + Envs: []string{"SERPER_API_KEY"}, + }, + ToolsToSkip: []string{}, // No tools to skip for this client + }, + schemas.MCPClientConfig{ + Name: "gmail-mcp", + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: bifrost.Ptr("https://mcp.composio.dev/composio/server/654c1e3f-ea7d-47b6-9e31-398d00449654/sse"), + }, + ) + + fmt.Println("πŸ”Œ Configuring Context7 MCP server...") + mcpConfig.ClientConfigs = append(mcpConfig.ClientConfigs, schemas.MCPClientConfig{ + Name: "context7", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"-y", "@upstash/context7-mcp"}, + }, + ToolsToSkip: []string{}, // No tools to skip for this client + }) + + // Initialize Bifrost with MCP configuration + account := &ComprehensiveTestAccount{} + + client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{ + Account: account, + Plugins: []schemas.Plugin{}, // No separate plugins needed - MCP is integrated + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + MCPConfig: mcpConfig, // MCP is now configured here + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize Bifrost: %w", err) + } + + session := &ChatSession{ + history: make([]schemas.BifrostMessage, 0), + client: client, + config: config, + account: account, + systemPrompt: "You are a helpful AI assistant with access to various tools. " + + "Use the available tools when they can help answer the user's questions more accurately or provide additional information.", + } + + // Add system message to history + if session.systemPrompt != "" { + session.history = append(session.history, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: &session.systemPrompt, + }, + }) + } + + return session, nil +} + +// getAvailableProviders returns a list of providers that have valid configurations +func (s *ChatSession) getAvailableProviders() []schemas.ModelProvider { + configuredProviders, err := s.account.GetConfiguredProviders() + if err != nil { + return []schemas.ModelProvider{} + } + + var availableProviders []schemas.ModelProvider + for _, provider := range configuredProviders { + // Check if provider has valid keys (except for keyless providers) + if provider == schemas.Ollama || provider == schemas.Vertex { + availableProviders = append(availableProviders, provider) + continue + } + ctx := context.Background() + keys, err := s.account.GetKeysForProvider(&ctx, provider) + if err == nil && len(keys) > 0 && keys[0].Value != "" { + availableProviders = append(availableProviders, provider) + } + } + return availableProviders +} + +// getAvailableModels returns available models for a given provider +func (s *ChatSession) getAvailableModels(provider schemas.ModelProvider) []string { + ctx := context.Background() + keys, err := s.account.GetKeysForProvider(&ctx, provider) + if err != nil || len(keys) == 0 { + return []string{} + } + return keys[0].Models +} + +// switchProvider handles switching to a different provider +func (s *ChatSession) switchProvider() error { + availableProviders := s.getAvailableProviders() + if len(availableProviders) == 0 { + fmt.Println("❌ No available providers found") + return fmt.Errorf("no available providers") + } + + fmt.Println("\nπŸ”„ Available Providers:") + fmt.Println("======================") + for i, provider := range availableProviders { + status := "" + if provider == s.config.Provider { + status = " (current)" + } + fmt.Printf("[%d] %s%s\n", i+1, provider, status) + } + + fmt.Print("\nSelect provider (number): ") + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + choice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || choice < 1 || choice > len(availableProviders) { + return fmt.Errorf("invalid choice") + } + + newProvider := availableProviders[choice-1] + + // Get available models for the new provider + models := s.getAvailableModels(newProvider) + if len(models) == 0 { + return fmt.Errorf("no models available for provider %s", newProvider) + } + + // Auto-select first model or let user choose if multiple + var newModel string + if len(models) == 1 { + newModel = models[0] + } else { + fmt.Printf("\n🧠 Available Models for %s:\n", newProvider) + fmt.Println("================================") + for i, model := range models { + fmt.Printf("[%d] %s\n", i+1, model) + } + + fmt.Print("\nSelect model (number): ") + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + modelChoice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || modelChoice < 1 || modelChoice > len(models) { + return fmt.Errorf("invalid model choice") + } + + newModel = models[modelChoice-1] + } + + // Update configuration + s.config.Provider = newProvider + s.config.Model = newModel + + fmt.Printf("βœ… Switched to %s with model %s\n", newProvider, newModel) + return nil +} + +// switchModel handles switching to a different model for the current provider +func (s *ChatSession) switchModel() error { + models := s.getAvailableModels(s.config.Provider) + if len(models) == 0 { + return fmt.Errorf("no models available for provider %s", s.config.Provider) + } + + if len(models) == 1 { + fmt.Printf("Only one model available for %s: %s\n", s.config.Provider, models[0]) + return nil + } + + fmt.Printf("\n🧠 Available Models for %s:\n", s.config.Provider) + fmt.Println("===============================") + for i, model := range models { + status := "" + if model == s.config.Model { + status = " (current)" + } + fmt.Printf("[%d] %s%s\n", i+1, model, status) + } + + fmt.Print("\nSelect model (number): ") + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + choice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || choice < 1 || choice > len(models) { + return fmt.Errorf("invalid choice") + } + + newModel := models[choice-1] + s.config.Model = newModel + + fmt.Printf("βœ… Switched to model %s\n", newModel) + return nil +} + +// showCurrentConfig displays the current configuration +func (s *ChatSession) showCurrentConfig() { + fmt.Println("\nβš™οΈ Current Configuration:") + fmt.Println("=========================") + fmt.Printf("πŸ”§ Provider: %s\n", s.config.Provider) + fmt.Printf("🧠 Model: %s\n", s.config.Model) + fmt.Printf("πŸ”„ Agentic Mode: %t\n", s.config.MCPAgenticMode) + fmt.Printf("🌑️ Temperature: %.1f\n", *s.config.Temperature) + fmt.Printf("πŸ“ Max Tokens: %d\n", *s.config.MaxTokens) + fmt.Printf("πŸ”§ Tool Execution: Manual approval required\n") +} + +// AddUserMessage adds a user message to the conversation history +func (s *ChatSession) AddUserMessage(message string) { + userMessage := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &message, + }, + } + s.history = append(s.history, userMessage) +} + +// SendMessage sends a message and returns the assistant's response +func (s *ChatSession) SendMessage(message string) (string, error) { + // Add user message to history + s.AddUserMessage(message) + + // Prepare model parameters + params := &schemas.ModelParameters{} + if s.config.Temperature != nil { + params.Temperature = s.config.Temperature + } + if s.config.MaxTokens != nil { + params.MaxTokens = s.config.MaxTokens + } + params.ToolChoice = &schemas.ToolChoice{ + ToolChoiceStr: stringPtr("auto"), + } + + // Create request + request := &schemas.BifrostRequest{ + Provider: s.config.Provider, + Model: s.config.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &s.history, + }, + Params: params, + } + + // Start loading animation + stopChan, wg := startLoader() + + // Send request + response, err := s.client.ChatCompletionRequest(context.Background(), request) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + return "", fmt.Errorf("chat completion failed: %s", err.Error.Message) + } + + if response == nil || len(response.Choices) == 0 { + return "", fmt.Errorf("no response received") + } + + // Get the assistant's response + choice := response.Choices[0] + assistantMessage := choice.Message + + // Add assistant message to history + s.history = append(s.history, assistantMessage) + + // Check if assistant wants to use tools + if assistantMessage.ToolCalls != nil && len(*assistantMessage.ToolCalls) > 0 { + return s.handleToolCalls(assistantMessage) + } + + // Extract text content for regular responses + var responseText string + if assistantMessage.Content.ContentStr != nil { + responseText = *assistantMessage.Content.ContentStr + } else if assistantMessage.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *assistantMessage.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + responseText = strings.Join(textParts, "\n") + } + + return responseText, nil +} + +// handleToolCalls handles tool execution using the new Bifrost MCP integration +func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) (string, error) { + toolCalls := *assistantMessage.ToolCalls + + // Display tools to user for approval + fmt.Println("\nπŸ”§ Assistant wants to use the following tools:") + fmt.Println("============================================") + + for i, toolCall := range toolCalls { + fmt.Printf("[%d] Tool: %s\n", i+1, *toolCall.Function.Name) + fmt.Printf(" Arguments: %s\n", toolCall.Function.Arguments) + fmt.Println() + } + + fmt.Print("Do you want to execute these tools? (y/n): ") + + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return "❌ Tool execution cancelled by user.", nil + } + + input := strings.ToLower(strings.TrimSpace(scanner.Text())) + if input != "y" && input != "yes" { + return "❌ Tool execution cancelled by user.", nil + } + + fmt.Println("βœ… Executing tools...") + + // Execute each tool using Bifrost's ExecuteMCPTool method + toolResults := make([]schemas.BifrostMessage, 0) + for _, toolCall := range toolCalls { + // Start loading animation for this tool + stopChan, wg := startLoader() + + // Execute the tool using Bifrost's integrated MCP functionality + toolResult, err := s.client.ExecuteMCPTool(context.Background(), toolCall) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + fmt.Printf("❌ Error executing tool %s: %v\n", *toolCall.Function.Name, err) + // Create error message for this tool + errorResult := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: stringPtr(fmt.Sprintf("Error executing tool: %v", err)), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: toolCall.ID, + }, + } + toolResults = append(toolResults, errorResult) + } else { + fmt.Printf("βœ… Tool %s executed successfully\n", *toolCall.Function.Name) + toolResults = append(toolResults, *toolResult) + } + } + + // Add tool results to conversation history + s.history = append(s.history, toolResults...) + + // If agentic mode is enabled, send conversation back to LLM for synthesis + if s.config.MCPAgenticMode { + return s.synthesizeToolResults() + } + + // Non-agentic mode: return the results directly + var responseText strings.Builder + responseText.WriteString("πŸ”§ Tool execution completed:\n\n") + + for i, result := range toolResults { + if result.Content.ContentStr != nil { + responseText.WriteString(fmt.Sprintf("Tool %d result: %s\n", i+1, *result.Content.ContentStr)) + } + } + + return responseText.String(), nil +} + +// synthesizeToolResults sends the conversation with tool results back to LLM for synthesis +func (s *ChatSession) synthesizeToolResults() (string, error) { + // Add synthesis prompt + synthesisPrompt := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: stringPtr("Please provide a comprehensive response based on the tool results above."), + }, + } + + // Temporarily add synthesis prompt for the request + conversationWithSynthesis := append(s.history, synthesisPrompt) + + // Create synthesis request + synthesisRequest := &schemas.BifrostRequest{ + Provider: s.config.Provider, + Model: s.config.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationWithSynthesis, + }, + Params: &schemas.ModelParameters{ + Temperature: s.config.Temperature, + MaxTokens: s.config.MaxTokens, + }, + } + + fmt.Println("πŸ€– Synthesizing response...") + + // Start loading animation + stopChan, wg := startLoader() + + // Send synthesis request + synthesisResponse, err := s.client.ChatCompletionRequest(context.Background(), synthesisRequest) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + fmt.Printf("⚠️ Synthesis failed: %v. Returning tool results directly.\n", err) + // Fallback to direct tool results + var responseText strings.Builder + responseText.WriteString("πŸ”§ Tool execution completed (synthesis failed):\n\n") + + // Get tool results from history (last few messages that are tool messages) + for i := len(s.history) - 1; i >= 0; i-- { + if s.history[i].Role == schemas.ModelChatMessageRoleTool { + if s.history[i].Content.ContentStr != nil { + responseText.WriteString(fmt.Sprintf("Tool result: %s\n", *s.history[i].Content.ContentStr)) + } + } else { + break // Stop when we hit non-tool messages + } + } + + return responseText.String(), nil + } + + if synthesisResponse == nil || len(synthesisResponse.Choices) == 0 { + return "❌ No synthesis response received", nil + } + + // Get synthesized response + synthesizedMessage := synthesisResponse.Choices[0].Message + + // Add synthesized response to history (replace the temporary synthesis prompt effect) + s.history = append(s.history, synthesizedMessage) + + // Extract text content + var responseText string + if synthesizedMessage.Content.ContentStr != nil { + responseText = *synthesizedMessage.Content.ContentStr + } else if synthesizedMessage.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *synthesizedMessage.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + responseText = strings.Join(textParts, "\n") + } + + return responseText, nil +} + +// PrintHistory prints the conversation history +func (s *ChatSession) PrintHistory() { + fmt.Println("\nπŸ“œ Conversation History:") + fmt.Println("========================") + + for i, msg := range s.history { + if msg.Role == schemas.ModelChatMessageRoleSystem { + continue // Skip system messages in history display + } + + var content string + if msg.Content.ContentStr != nil { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + content = strings.Join(textParts, "\n") + } + + role := cases.Title(language.English).String(string(msg.Role)) + timestamp := fmt.Sprintf("[%d]", i) + + fmt.Printf("%s %s: %s\n\n", timestamp, role, content) + } +} + +// Cleanup closes the chat session and cleans up resources +func (s *ChatSession) Cleanup() { + if s.client != nil { + s.client.Shutdown() + } +} + +// printWelcome prints the welcome message and instructions +func printWelcome(config ChatbotConfig) { + fmt.Println("πŸ€– Bifrost CLI Chatbot") + fmt.Println("======================") + fmt.Printf("πŸ”§ Provider: %s\n", config.Provider) + fmt.Printf("🧠 Model: %s\n", config.Model) + fmt.Printf("πŸ”„ Agentic Mode: %t\n", config.MCPAgenticMode) + fmt.Printf("πŸ”§ Tool Execution: Manual approval required\n") + fmt.Println() + fmt.Println("Commands:") + fmt.Println(" /help - Show this help message") + fmt.Println(" /history - Show conversation history") + fmt.Println(" /clear - Clear conversation history") + fmt.Println(" /config - Show current configuration") + fmt.Println(" /provider - Switch provider") + fmt.Println(" /model - Switch model") + fmt.Println(" /quit - Exit the chatbot") + fmt.Println() + fmt.Println("Type your message and press Enter to chat!") + fmt.Println("When the assistant wants to use tools, you'll be asked to approve them.") + fmt.Println("==========================================") +} + +// printHelp prints help information +func printHelp() { + fmt.Println("\nπŸ“– Help") + fmt.Println("========") + fmt.Println("Available commands:") + fmt.Println(" /help - Show this help message") + fmt.Println(" /history - Show conversation history") + fmt.Println(" /clear - Clear conversation history (keeps system prompt)") + fmt.Println(" /config - Show current provider, model, and settings") + fmt.Println(" /provider - Switch between different AI providers") + fmt.Println(" /model - Switch between models for current provider") + fmt.Println(" /quit - Exit the chatbot") + fmt.Println() + fmt.Println("Supported providers:") + fmt.Println("β€’ OpenAI (gpt-4o-mini, gpt-4-turbo, gpt-4o)") + fmt.Println("β€’ Anthropic (claude models)") + fmt.Println("β€’ Bedrock (AWS hosted models)") + fmt.Println("β€’ Cohere (command models)") + fmt.Println("β€’ Azure (Azure OpenAI models)") + fmt.Println("β€’ Vertex (Google Cloud models)") + fmt.Println("β€’ Mistral (mistral models)") + fmt.Println("β€’ Ollama (local models)") + fmt.Println() + fmt.Println("Tool execution:") + fmt.Println("β€’ When the assistant wants to use tools, you'll be asked to approve them") + fmt.Println("β€’ You can review the tool names and arguments before approving") + fmt.Println("β€’ Available tools include web search and Context7") + fmt.Println("β€’ In agentic mode, tool results are synthesized into natural responses") + fmt.Println("β€’ In non-agentic mode, raw tool results are displayed") + fmt.Println() +} + +// stringPtr is a helper function to create string pointers +func stringPtr(s string) *string { + return &s +} + +// startLoader starts a loading spinner animation +func startLoader() (chan bool, *sync.WaitGroup) { + stopChan := make(chan bool) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + spinner := []string{"β ‹", "β ™", "β Ή", "β Έ", "β Ό", "β ΄", "β ¦", "β §", "β ‡", "⠏"} + i := 0 + + for { + select { + case <-stopChan: + // Clear the spinner + fmt.Print("\r\033[K") // Clear current line + return + default: + fmt.Printf("\rπŸ€– Assistant: %s Thinking...", spinner[i%len(spinner)]) + i++ + time.Sleep(100 * time.Millisecond) + } + } + }() + + return stopChan, &wg +} + +// stopLoader stops the loading animation +func stopLoader(stopChan chan bool, wg *sync.WaitGroup) { + close(stopChan) + wg.Wait() +} + +func main() { + // Check for required environment variables + if os.Getenv("OPENAI_API_KEY") == "" { + fmt.Println("❌ Error: OPENAI_API_KEY environment variable is required") + fmt.Println("πŸ’‘ Set additional provider API keys to access more models:") + fmt.Println(" - ANTHROPIC_API_KEY for Claude models") + fmt.Println(" - COHERE_API_KEY for Cohere models") + fmt.Println(" - MISTRAL_API_KEY for Mistral models") + fmt.Println(" - AWS credentials for Bedrock") + fmt.Println(" - AZURE_API_KEY and AZURE_ENDPOINT for Azure OpenAI") + fmt.Println(" - VERTEX_PROJECT_ID and credentials for Vertex AI") + os.Exit(1) + } + + // Default configuration + config := ChatbotConfig{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + MCPAgenticMode: true, + MCPServerPort: 8585, + Temperature: bifrost.Ptr(0.7), + MaxTokens: bifrost.Ptr(1000), + } + + // Create chat session + fmt.Println("πŸš€ Starting Bifrost CLI Chatbot...") + session, err := NewChatSession(config) + if err != nil { + fmt.Printf("❌ Failed to create chat session: %v\n", err) + os.Exit(1) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + fmt.Println("\n\nπŸ‘‹ Goodbye! Cleaning up...") + session.Cleanup() + os.Exit(0) + }() + + // Give MCP servers time to initialize + fmt.Println("⏳ Waiting for MCP servers to initialize...") + time.Sleep(3 * time.Second) + + // Print welcome message + printWelcome(config) + + // Main chat loop + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("\nπŸ’¬ You: ") + if !scanner.Scan() { + break + } + + input := strings.TrimSpace(scanner.Text()) + if input == "" { + continue + } + + // Handle commands + switch input { + case "/help": + printHelp() + continue + case "/history": + session.PrintHistory() + continue + case "/clear": + // Keep system prompt but clear conversation history + systemPrompt := session.history[0] // Assuming first message is system + session.history = []schemas.BifrostMessage{systemPrompt} + fmt.Println("🧹 Conversation history cleared!") + continue + case "/config": + session.showCurrentConfig() + continue + case "/provider": + if err := session.switchProvider(); err != nil { + fmt.Printf("❌ Error switching provider: %v\n", err) + } + continue + case "/model": + if err := session.switchModel(); err != nil { + fmt.Printf("❌ Error switching model: %v\n", err) + } + continue + case "/quit": + fmt.Println("πŸ‘‹ Goodbye!") + session.Cleanup() + return + } + + // Send message and get response + response, err := session.SendMessage(input) + if err != nil { + fmt.Printf("\rπŸ€– Assistant: ❌ Error: %v\n", err) + continue + } + + fmt.Printf("πŸ€– Assistant: %s\n", response) + } + + // Cleanup + session.Cleanup() +} diff --git a/tests/core-providers/README.md b/tests/core-providers/README.md new file mode 100644 index 000000000..5b5fa6778 --- /dev/null +++ b/tests/core-providers/README.md @@ -0,0 +1,441 @@ +# Bifrost Core Providers Test Suite πŸš€ + +This directory contains comprehensive tests for all Bifrost AI providers, ensuring compatibility and functionality across different AI services. + +## πŸ“‹ Supported Providers + +- **OpenAI** - GPT models and function calling +- **Anthropic** - Claude models +- **Azure OpenAI** - Azure-hosted OpenAI models +- **AWS Bedrock** - Amazon's managed AI service +- **Cohere** - Cohere's language models +- **Google Vertex AI** - Google Cloud's AI platform +- **Mistral** - Mistral AI models with vision capabilities +- **Ollama** - Local LLM serving platform +- **Groq** - OSS models +- **SGLang** - OSS models +- **Parasail** - OSS models +- **Cerebras** - Llama, Qwen and GPT-OSS models +- **Gemini** - Gemini models +- **OpenRouter** - Models supported by OpenRouter + +## πŸƒβ€β™‚οΈ Running Tests + +### Development with Local Bifrost Core + +To test changes with a forked or local version of bifrost-core: + +1. **Uncomment the replace directive** in `tests/core-providers/go.mod`: + + ```go + // Uncomment this line to use your local bifrost-core + replace github.com/maximhq/bifrost/core => ../../core + ``` + +2. **Update dependencies**: + + ```bash + cd tests/core-providers + go mod tidy + ``` + +3. **Run tests** with your local changes: + + ```bash + go test -v ./tests/core-providers/ + ``` + +⚠️ **Important**: Ensure your local `../../core` directory contains your bifrost-core implementation. The path should be relative to the `tests/core-providers` directory. + +### Prerequisites + +Set up environment variables for the providers you want to test: + +```bash +# OpenAI +export OPENAI_API_KEY="your-openai-key" + +# Anthropic +export ANTHROPIC_API_KEY="your-anthropic-key" + +# Azure OpenAI +export AZURE_API_KEY="your-azure-key" +export AZURE_ENDPOINT="your-azure-endpoint" + +# AWS Bedrock +export AWS_ACCESS_KEY_ID_ID="your-aws-access-key" +export AWS_SECRET_ACCESS_KEY="your-aws-secret-key" +export AWS_REGION="us-east-1" + +# Cohere +export COHERE_API_KEY="your-cohere-key" + +# Google Vertex AI +export GOOGLE_APPLICATION_CREDENTIALS="path/to/service-account.json" +export GOOGLE_PROJECT_ID="your-project-id" + +# Mistral AI +export MISTRAL_API_KEY="your-mistral-key" + +# Gemini +export GEMINI_API_KEY="your-gemini-key" + +# Ollama (local installation) +# No API key required - ensure Ollama is running locally +# Default endpoint: http://localhost:11434 +``` + +### Run All Provider Tests + +```bash +# Run all tests with verbose output (recommended) +go test -v ./tests/core-providers/ + +# Run with debug logs +go test -v ./tests/core-providers/ -debug +``` + +### Run Specific Provider Tests + +```bash +# Test only OpenAI +go test -v ./tests/core-providers/ -run TestOpenAI + +# Test only Anthropic +go test -v ./tests/core-providers/ -run TestAnthropic + +# Test only Azure +go test -v ./tests/core-providers/ -run TestAzure + +# Test only Bedrock +go test -v ./tests/core-providers/ -run TestBedrock + +# Test only Cohere +go test -v ./tests/core-providers/ -run TestCohere + +# Test only Vertex AI +go test -v ./tests/core-providers/ -run TestVertex + +# Test only Mistral +go test -v ./tests/core-providers/ -run TestMistral + +# Test only Gemini +go test -v ./tests/core-providers/ -run TestGemini + +# Test only Ollama +go test -v ./tests/core-providers/ -run TestOllama +``` + +### Run Specific Test Scenarios + +You can run specific scenarios across all providers: + +```bash +# Test only chat completion +go test -v ./tests/core-providers/ -run "Chat" + +# Test only function calling +go test -v ./tests/core-providers/ -run "Function" +``` + +### Run Specific Scenario for Specific Provider + +You can combine provider and scenario filters to test specific functionality: + +```bash +# Test only OpenAI simple chat +go test -v ./tests/core-providers/ -run "TestOpenAI/SimpleChat" + +# Test only Anthropic tool calls +go test -v ./tests/core-providers/ -run "TestAnthropic/ToolCalls" + +# Test only Azure multi-turn conversation +go test -v ./tests/core-providers/ -run "TestAzure/MultiTurnConversation" + +# Test only Bedrock text completion +go test -v ./tests/core-providers/ -run "TestBedrock/TextCompletion" + +# Test only Cohere image URL processing +go test -v ./tests/core-providers/ -run "TestCohere/ImageURL" + +# Test only Vertex automatic function calling +go test -v ./tests/core-providers/ -run "TestVertex/AutomaticFunctionCalling" + +# Test only Mistral image processing +go test -v ./tests/core-providers/ -run "TestMistral/ImageURL" + +# Test only Gemini simple chat +go test -v ./tests/core-providers/ -run "TestGemini/SimpleChat" + +# Test only Ollama simple chat +go test -v ./tests/core-providers/ -run "TestOllama/SimpleChat" +``` + +**Available Scenario Names:** + +- `SimpleChat` - Basic chat completion +- `TextCompletion` - Text completion (legacy models) +- `MultiTurnConversation` - Multi-turn chat conversations +- `ToolCalls` - Basic function/tool calling +- `MultipleToolCalls` - Multiple tool calls in one request +- `End2EndToolCalling` - Complete tool calling workflow +- `AutomaticFunctionCalling` - Automatic function selection +- `ImageURL` - Image processing from URLs +- `ImageBase64` - Image processing from base64 +- `MultipleImages` - Multiple image processing +- `CompleteEnd2End` - Full end-to-end test +- `ProviderSpecific` - Provider-specific features +- `Embedding` - Basic embedding request + +## πŸ§ͺ Test Scenarios + +Each provider is tested against these scenarios when supported: + +βœ… **Supported by Most Providers:** + +- Simple Text Completion +- Simple Chat Completion +- Multi-turn Chat Conversation +- Chat with System Message +- Text Completion with Parameters +- Chat Completion with Parameters +- Error Handling (Invalid Model) +- Model Information Retrieval +- Simple Function Calling + +❌ **Provider-Specific Support:** + +- **Automatic Function Calling**: OpenAI, Anthropic, Bedrock, Azure, Vertex, Mistral, Ollama, Gemini +- **Vision/Image Analysis**: OpenAI, Anthropic, Bedrock, Azure, Vertex, Mistral, Gemini (limited support for Cohere and Ollama) +- **Text Completion**: Legacy models only (most providers now focus on chat completion) + +## πŸ“Š Understanding Test Output + +The test suite provides rich visual feedback: + +- πŸš€ **Test suite starting** +- βœ… **Successful operations and supported tests** +- ❌ **Failed operations and unsupported features** +- ⏭️ **Skipped scenarios (not supported by provider)** +- πŸ“Š **Summary statistics** +- ℹ️ **Informational notes** + +Example output: + +```text +=== RUN TestOpenAI +πŸš€ Starting comprehensive test suite for OpenAI provider... +βœ… Simple Text Completion test completed successfully +βœ… Simple Chat Completion test completed successfully +⏭️ Automatic Function Calling not supported by this provider +πŸ“Š Test Summary for OpenAI: +βœ…βœ… Supported Tests: 11 +❌ Unsupported Tests: 1 +``` + +## πŸ”§ Adding New Providers + +To add a new provider to the test suite: + +### 1. Create Provider Test File + +Create a new file `{provider}_test.go`: + +```go +package tests + +import ( + "testing" + "github.com/BifrostDev/bifrost/pkg/client" +) + +func TestNewProvider(t *testing.T) { + config := client.Config{ + Provider: "newprovider", + APIKey: getEnvVar("NEW_PROVIDER_API_KEY"), + // Add other required config fields + } + + // Skip if no API key provided + if config.APIKey == "" { + t.Skip("NEW_PROVIDER_API_KEY not set, skipping NewProvider tests") + } + + runProviderTests(t, config, "NewProvider") +} +``` + +### 2. Update Provider Configuration + +Add your provider's capabilities in `tests.go`: + +```go +func getProviderCapabilities(providerName string) ProviderCapabilities { + switch providerName { + case "NewProvider": + return ProviderCapabilities{ + SupportsTextCompletion: true, + SupportsChatCompletion: true, + SupportsFunctionCalling: false, // Update based on provider + SupportsAutomaticFunctions: false, + SupportsVision: false, + SupportsSystemMessages: true, + SupportsMultiTurn: true, + SupportsParameters: true, + SupportsModelInfo: true, + SupportsErrorHandling: true, + } + // ... other cases + } +} +``` + +### 3. Add Default Models + +Add default models for your provider: + +```go +func getDefaultModel(providerName string) string { + switch providerName { + case "NewProvider": + return "newprovider-model-name" + // ... other cases + } +} +``` + +### 4. Environment Variables + +Document any required environment variables in this README and ensure they're handled in the test setup. + +### 5. Test Your Implementation + +Run your new provider tests: + +```bash +go test -v ./tests/core-providers/ -run TestNewProvider +``` + +## πŸ› οΈ Troubleshooting + +### Common Issues + +1. **Tests being skipped**: Make sure environment variables are set correctly +2. **Connection timeouts**: Check your network connection and API endpoints +3. **Authentication errors**: Verify your API keys are valid and have proper permissions +4. **Missing logs**: Use `-v` flag to see detailed test output +5. **Rate limiting**: Some providers have rate limits; tests may need delays +6. **Ollama connection issues**: Ensure Ollama is running locally (`ollama serve`) +7. **Mistral vision failures**: Check if your account has access to Pixtral models + +### Debug Mode + +Enable debug logging to see detailed API interactions: + +```bash +go test -v ./tests/core-providers/ -debug +``` + +### Provider-Specific Considerations + +#### Mistral AI + +- **Models**: Uses `pixtral-12b-latest` for vision tasks +- **Capabilities**: Full support for chat, tools, and vision +- **API Key**: Required via `MISTRAL_API_KEY` environment variable + +#### Gemini + +- **Models**: Uses `gemini-2.0-flash` for chat and `text-embedding-004` for embeddings +- **Capabilities**: Full support for chat, tools, vision (base64), speech synthesis, and transcription +- **API Key**: Required via `GEMINI_API_KEY` environment variable +- **Limitations**: No text completion support, limited image URL support (base64 preferred) + +#### Ollama + +- **Local Setup**: Requires Ollama to be running locally (default: `http://localhost:11434`) +- **Models**: Uses `llama3.2` model by default +- **No API Key**: Authentication not required for local instances +- **Limitations**: No vision/image processing support +- **Installation**: [Download from ollama.ai](https://ollama.ai/) and ensure the service is running + +### Checking Provider Status + +If a provider seems to be failing, you can check their status pages: + +- [OpenAI Status](https://status.openai.com/) +- [Anthropic Status](https://status.anthropic.com/) +- [Azure Status](https://status.azure.com/) +- [AWS Status](https://status.aws.amazon.com/) +- [Mistral Status](https://status.mistral.ai/) + +## πŸ“ Test Coverage + +The comprehensive test suite covers: + +- βœ… **Text Completion** - Legacy completion models (where supported) +- βœ… **Simple Chat** - Basic chat completion functionality +- βœ… **Multi-Turn Conversations** - Context maintenance across messages +- βœ… **Tool Calls** - Basic function/tool calling capabilities +- βœ… **Multiple Tool Calls** - Multiple tools in a single request +- βœ… **End-to-End Tool Calling** - Complete tool workflow with result integration +- βœ… **Automatic Function Calling** - Provider-managed tool execution +- βœ… **Image URL Processing** - Image analysis from URLs +- βœ… **Image Base64 Processing** - Image analysis from base64 encoded data +- βœ… **Multiple Images** - Multi-image analysis and comparison +- βœ… **Complete End-to-End** - Full multimodal workflows +- βœ… **Provider-Specific Features** - Integration-unique capabilities + +### Provider Capability Matrix + +| Provider | Chat | Tools | Vision | Text Completion | Auto Functions | +| --------- | ---- | ----- | ------ | --------------- | -------------- | +| OpenAI | βœ… | βœ… | βœ… | ❌ | βœ… | +| Anthropic | βœ… | βœ… | βœ… | βœ… | βœ… | +| Azure | βœ… | βœ… | βœ… | βœ… | βœ… | +| Bedrock | βœ… | βœ… | βœ… | βœ… | βœ… | +| Vertex | βœ… | βœ… | βœ… | ❌ | βœ… | +| Cohere | βœ… | βœ… | ❌ | ❌ | ❌ | +| Mistral | βœ… | βœ… | βœ… | ❌ | βœ… | +| Ollama | βœ… | βœ… | ❌ | ❌ | βœ… | +| Gemini | βœ… | βœ… | βœ… | ❌ | βœ… | + +## 🀝 Contributing + +When adding new providers or test scenarios: + +### Adding New Providers + +1. **Create test file**: Add `{provider}_test.go` following the existing pattern +2. **Update config**: Add provider configuration in `config/account.go`: + - Add to `GetKeysForProvider()` (if API key required) + - Add to `GetConfigForProvider()` + - Add to `GetConfiguredProviders()` list +3. **Test scenarios**: Configure supported scenarios in the test file +4. **Documentation**: Update this README with environment variables and capabilities +5. **Testing**: Test with multiple scenarios to verify integration + +### Adding New Test Scenarios + +1. **Implement scenario**: Add new test function in `scenarios/` directory +2. **Update structure**: Add scenario to `TestScenarios` struct in `config/account.go` +3. **Configure providers**: Update each provider's scenario configuration +4. **Update runner**: Add scenario call to `runAllComprehensiveTests()` in `tests.go` +5. **Documentation**: Update README with scenario description and examples + +### Testing Your Changes + +```bash +# Test specific provider +go test -v ./tests/core-providers/ -run TestYourProvider + +# Test all providers +go test -v ./tests/core-providers/ + +# Test with debug output +go test -v ./tests/core-providers/ -debug +``` + +## πŸ“„ License + +This test suite is part of the Bifrost project and follows the same license terms. diff --git a/tests/core-providers/anthropic_test.go b/tests/core-providers/anthropic_test.go new file mode 100644 index 000000000..3878e0a6d --- /dev/null +++ b/tests/core-providers/anthropic_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAnthropic(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Anthropic, + ChatModel: "claude-3-7-sonnet-20250219", + TextModel: "", // Anthropic doesn't support text completion + EmbeddingModel: "", // Anthropic doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/azure_test.go b/tests/core-providers/azure_test.go new file mode 100644 index 000000000..fd563321e --- /dev/null +++ b/tests/core-providers/azure_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAzure(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Azure, + ChatModel: "gpt-4o", + TextModel: "", // Azure OpenAI doesn't support text completion in newer models + EmbeddingModel: "text-embedding-3-small", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/bedrock_test.go b/tests/core-providers/bedrock_test.go new file mode 100644 index 000000000..2eda3e480 --- /dev/null +++ b/tests/core-providers/bedrock_test.go @@ -0,0 +1,48 @@ +package tests + +import ( + "os" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestBedrock(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { + t.Skip("Skipping Bedrock embedding: AWS credentials not set") + } + + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Bedrock, + ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", + TextModel: "", // Bedrock Claude doesn't support text completion + EmbeddingModel: "amazon.titan-embed-text-v2:0", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: true, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/cerebras_test.go b/tests/core-providers/cerebras_test.go new file mode 100644 index 000000000..9a750eeb9 --- /dev/null +++ b/tests/core-providers/cerebras_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestCerebras(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Cerebras, + ChatModel: "llama-3.3-70b", + TextModel: "llama3.1-8b", + EmbeddingModel: "", // Cerebras doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: true, + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: false, + Embedding: false, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/cohere_test.go b/tests/core-providers/cohere_test.go new file mode 100644 index 000000000..6ff9d7338 --- /dev/null +++ b/tests/core-providers/cohere_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestCohere(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Cohere, + ChatModel: "command-a-03-2025", + TextModel: "", // Cohere focuses on chat + EmbeddingModel: "embed-english-v3.0", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go new file mode 100644 index 000000000..d2f26587a --- /dev/null +++ b/tests/core-providers/config/account.go @@ -0,0 +1,721 @@ +// Package config provides comprehensive test account and configuration management for the Bifrost system. +// It implements account functionality for testing purposes, supporting multiple AI providers +// and comprehensive test scenarios. +package config + +import ( + "context" + "fmt" + "os" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// ProviderOpenAICustom represents the custom OpenAI provider for testing +const ProviderOpenAICustom = schemas.ModelProvider("openai-custom") + +// TestScenarios defines the comprehensive test scenarios +type TestScenarios struct { + TextCompletion bool + SimpleChat bool + ChatCompletionStream bool + MultiTurnConversation bool + ToolCalls bool + MultipleToolCalls bool + End2EndToolCalling bool + AutomaticFunctionCall bool + ImageURL bool + ImageBase64 bool + MultipleImages bool + CompleteEnd2End bool + ProviderSpecific bool + SpeechSynthesis bool // Text-to-speech functionality + SpeechSynthesisStream bool // Streaming text-to-speech functionality + Transcription bool // Speech-to-text functionality + TranscriptionStream bool // Streaming speech-to-text functionality + Embedding bool // Embedding functionality +} + +// ComprehensiveTestConfig extends TestConfig with additional scenarios +type ComprehensiveTestConfig struct { + Provider schemas.ModelProvider + ChatModel string + TextModel string + EmbeddingModel string + TranscriptionModel string + SpeechSynthesisModel string + Scenarios TestScenarios + CustomParams *schemas.ModelParameters + Fallbacks []schemas.Fallback + SkipReason string // Reason to skip certain tests +} + +// ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. +type ComprehensiveTestAccount struct{} + +// getEnvWithDefault returns the value of the environment variable if set, otherwise returns the default value +func getEnvWithDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +// GetConfiguredProviders returns the list of initially supported providers. +func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{ + schemas.OpenAI, + schemas.Anthropic, + schemas.Bedrock, + schemas.Cohere, + schemas.Azure, + schemas.Vertex, + schemas.Ollama, + schemas.Mistral, + schemas.Groq, + schemas.SGL, + schemas.Parasail, + schemas.Cerebras, + schemas.Gemini, + schemas.OpenRouter, + ProviderOpenAICustom, + }, nil +} + +// GetKeysForProvider returns the API keys and associated models for a given provider. +func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + switch providerKey { + case schemas.OpenAI: + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + case ProviderOpenAICustom: + return []schemas.Key{ + { + Value: os.Getenv("GROQ_API_KEY"), // Use GROQ API key for OpenAI-compatible endpoint + Models: []string{}, + Weight: 1.0, + }, + }, nil + case schemas.Anthropic: + return []schemas.Key{ + { + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, + Weight: 1.0, + }, + }, nil + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, + Weight: 1.0, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), + SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + SessionToken: bifrost.Ptr(os.Getenv("AWS_SESSION_TOKEN")), + Region: bifrost.Ptr(getEnvWithDefault("AWS_REGION", "us-east-1")), + }, + }, + }, nil + case schemas.Cohere: + return []schemas.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025", "c4ai-aya-vision-8b"}, + Weight: 1.0, + }, + }, nil + case schemas.Azure: + return []schemas.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o", "text-embedding-3-small"}, + Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-aug", + "text-embedding-3-small": "text-embedding-3-small-deployment", + }, + // Use environment variable for API version with fallback to current preview version + // Note: This is a preview API version that may change over time. Update as needed. + // Set AZURE_API_VERSION environment variable to override the default. + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + }, + }, + }, nil + case schemas.Vertex: + return []schemas.Key{ + { + Value: os.Getenv("VERTEX_API_KEY"), + Models: []string{}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + }, + }, + }, nil + case schemas.Mistral: + return []schemas.Key{ + { + Value: os.Getenv("MISTRAL_API_KEY"), + Models: []string{"mistral-large-2411", "pixtral-12b-latest", "mistral-embed"}, + Weight: 1.0, + }, + }, nil + case schemas.Groq: + return []schemas.Key{ + { + Value: os.Getenv("GROQ_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + case schemas.Parasail: + return []schemas.Key{ + { + Value: os.Getenv("PARASAIL_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + case schemas.Cerebras: + return []schemas.Key{ + { + Value: os.Getenv("CEREBRAS_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + case schemas.Gemini: + return []schemas.Key{ + { + Value: os.Getenv("GEMINI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + case schemas.OpenRouter: + return []schemas.Key{ + { + Value: os.Getenv("OPENROUTER_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// GetConfigForProvider returns the configuration settings for a given provider. +func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch providerKey { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case ProviderOpenAICustom: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: getEnvWithDefault("GROQ_OPENAI_BASE_URL", "https://api.groq.com/openai"), + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + CustomProviderConfig: &schemas.CustomProviderConfig{ + BaseProviderType: schemas.OpenAI, + AllowedRequests: &schemas.AllowedRequests{ + TextCompletion: false, + ChatCompletion: true, + ChatCompletionStream: true, + Embedding: false, + Speech: false, + SpeechStream: false, + Transcription: false, + TranscriptionStream: false, + }, + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Bedrock: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Cohere: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Azure: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Vertex: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Ollama: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + BaseURL: getEnvWithDefault("OLLAMA_BASE_URL", "http://localhost:11434"), + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Mistral: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Groq: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.SGL: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: os.Getenv("SGL_BASE_URL"), + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Parasail: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Cerebras: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Gemini: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.OpenRouter: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// AllProviderConfigs contains test configurations for all providers +var AllProviderConfigs = []ComprehensiveTestConfig{ + { + Provider: schemas.OpenAI, + ChatModel: "gpt-4o-mini", + TextModel: "", // OpenAI doesn't support text completion in newer models + TranscriptionModel: "whisper-1", + SpeechSynthesisModel: "tts-1", + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: true, // OpenAI supports TTS + SpeechSynthesisStream: true, // OpenAI supports streaming TTS + Transcription: true, // OpenAI supports STT with Whisper + TranscriptionStream: true, // OpenAI supports streaming STT + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + }, + { + Provider: schemas.Anthropic, + ChatModel: "claude-3-7-sonnet-20250219", + TextModel: "", // Anthropic doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Bedrock, + ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", + TextModel: "", // Bedrock Claude doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Cohere, + ChatModel: "command-a-03-2025", + TextModel: "", // Cohere focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Azure, + ChatModel: "gpt-4o", + TextModel: "", // Azure OpenAI doesn't support text completion in newer models + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported yet + SpeechSynthesisStream: false, // Not supported yet + Transcription: false, // Not supported yet + TranscriptionStream: false, // Not supported yet + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Vertex, + ChatModel: "gemini-pro", + TextModel: "", // Vertex focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Mistral, + ChatModel: "mistral-large-2411", + TextModel: "", // Mistral focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Ollama, + ChatModel: "llama3.2", + TextModel: "", // Ollama focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Groq, + ChatModel: "llama-3.3-70b-versatile", + TextModel: "", // Groq doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: ProviderOpenAICustom, + ChatModel: "llama-3.3-70b-versatile", + TextModel: "", // Custom OpenAI instance doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, + SimpleChat: true, // Enable simple chat for testing + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Gemini, + ChatModel: "gemini-2.0-flash", + TextModel: "", // GenAI doesn't support text completion in newer models + TranscriptionModel: "gemini-2.5-flash", + SpeechSynthesisModel: "gemini-2.5-flash-preview-tts", + EmbeddingModel: "text-embedding-004", + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.OpenRouter, + ChatModel: "openai/gpt-4o", + TextModel: "google/gemini-2.5-flash", + Scenarios: TestScenarios{ + TextCompletion: true, + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: false, + TranscriptionStream: false, + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, +} diff --git a/tests/core-providers/config/setup.go b/tests/core-providers/config/setup.go new file mode 100644 index 000000000..8760edcc4 --- /dev/null +++ b/tests/core-providers/config/setup.go @@ -0,0 +1,60 @@ +// Package config provides comprehensive test utilities and configurations for the Bifrost system. +// It includes comprehensive test implementations covering all major AI provider scenarios, +// including text completion, chat, tool calling, image processing, and end-to-end workflows. +package config + +import ( + "context" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// Constants for test configuration +const ( + // TestTimeout defines the maximum duration for comprehensive tests + // Set to 5 minutes to allow for complex multi-step operations + TestTimeout = 5 * time.Minute +) + +// getBifrost initializes and returns a Bifrost instance for comprehensive testing. +// It sets up the comprehensive test account, plugin, and logger configuration. +// +// Environment variables are expected to be set by the system or test runner before calling this function. +// The account configuration will read API keys and settings from these environment variables. +// +// Returns: +// - *bifrost.Bifrost: A configured Bifrost instance ready for comprehensive testing +// - error: Any error that occurred during Bifrost initialization +// +// The function: +// 1. Creates a comprehensive test account instance +// 2. Configures Bifrost with the account and default logger +func getBifrost(ctx context.Context) (*bifrost.Bifrost, error) { + account := ComprehensiveTestAccount{} + + // Initialize Bifrost + b, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: &account, + Plugins: nil, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + return nil, err + } + + return b, nil +} + +// SetupTest initializes a test environment with timeout context +func SetupTest() (*bifrost.Bifrost, context.Context, context.CancelFunc, error) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + client, err := getBifrost(ctx) + if err != nil { + cancel() + return nil, nil, nil, err + } + + return client, ctx, cancel, nil +} diff --git a/tests/core-providers/custom_test.go b/tests/core-providers/custom_test.go new file mode 100644 index 000000000..2a650cc06 --- /dev/null +++ b/tests/core-providers/custom_test.go @@ -0,0 +1,135 @@ +package tests + +import ( + "os" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" +) + +func TestCustomProvider(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: config.ProviderOpenAICustom, + ChatModel: "llama-3.3-70b-versatile", + TextModel: "", // OpenAI doesn't support text completion in newer models + EmbeddingModel: "", // groq custom base: embeddings not supported + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} + +func TestCustomProvider_DisallowedOperation(t *testing.T) { + // Skip test if required API key is not available + if os.Getenv("GROQ_API_KEY") == "" { + t.Skipf("skipping test: GROQ_API_KEY not set") + } + + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + + // Create a speech request to the custom provider + prompt := "The future of artificial intelligence is" + request := &schemas.BifrostRequest{ + Provider: config.ProviderOpenAICustom, // Use the custom provider + Model: "llama-3.3-70b-versatile", // Use a model that exists for this provider + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: prompt, + }, + }, + Params: &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(100), + }, + } + + // Attempt to make a speech stream request + response, bifrostErr := client.SpeechStreamRequest(ctx, request) + + // Assert that the request failed with an error + assert.NotNil(t, bifrostErr, "Expected error for disallowed speech stream operation") + assert.Nil(t, response, "Expected no response for disallowed operation") + + // Assert that the error message contains "not supported" or "not supported by openai-custom" + msg := strings.ToLower(bifrostErr.Error.Message) + assert.Contains(t, msg, "not supported", "error should indicate operation is not supported") + assert.Contains(t, msg, string(config.ProviderOpenAICustom), "error should mention refusing provider") + assert.Equal(t, config.ProviderOpenAICustom, bifrostErr.Provider, "error should be attributed to the custom provider") +} + +func TestCustomProvider_MismatchedIdentity(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + // Use a provider that doesn't exist + wrongProvider := schemas.ModelProvider("wrong-provider") + + request := &schemas.BifrostRequest{ + Provider: wrongProvider, + Model: "llama-3.3-70b-versatile", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello! What's the capital of France?"), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(100), + }, + } + + // Attempt to make a chat completion request + response, bifrostErr := client.ChatCompletionRequest(ctx, request) + + // Assert that the request failed with an error + assert.NotNil(t, bifrostErr, "Expected error for mismatched identity") + assert.Nil(t, response, "Expected no response for mismatched identity") + + msg := strings.ToLower(bifrostErr.Error.Message) + assert.Contains(t, msg, "unsupported provider", "error should mention unsupported provider") + assert.Contains(t, msg, strings.ToLower(string(wrongProvider)), "error should mention the wrong provider") + assert.Equal(t, wrongProvider, bifrostErr.Provider, "error should include the unsupported provider identity") +} diff --git a/tests/core-providers/gemini_test.go b/tests/core-providers/gemini_test.go new file mode 100644 index 000000000..53f145ab3 --- /dev/null +++ b/tests/core-providers/gemini_test.go @@ -0,0 +1,54 @@ +package tests + +import ( + "os" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestGemini(t *testing.T) { + if os.Getenv("GEMINI_API_KEY") == "" { + t.Skip("GEMINI_API_KEY not set; skipping Gemini tests") + } + + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Gemini, + ChatModel: "gemini-2.0-flash", + TextModel: "", // Gemini doesn't support text completion + EmbeddingModel: "text-embedding-004", + TranscriptionModel: "gemini-2.5-flash", + SpeechSynthesisModel: "gemini-2.5-flash-preview-tts", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + Transcription: true, + TranscriptionStream: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/go.mod b/tests/core-providers/go.mod new file mode 100644 index 000000000..b232a1fd1 --- /dev/null +++ b/tests/core-providers/go.mod @@ -0,0 +1,58 @@ +module github.com/maximhq/bifrost/tests/core-providers + +go 1.24 + +toolchain go1.24.3 + +require ( + github.com/maximhq/bifrost/core v1.1.21 + github.com/stretchr/testify v1.10.0 +) + +replace github.com/maximhq/bifrost/core => ../../core + +require ( + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/core-providers/go.sum b/tests/core-providers/go.sum new file mode 100644 index 000000000..feb125811 --- /dev/null +++ b/tests/core-providers/go.sum @@ -0,0 +1,123 @@ +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-providers/groq_test.go b/tests/core-providers/groq_test.go new file mode 100644 index 000000000..9b0e24081 --- /dev/null +++ b/tests/core-providers/groq_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestGroq(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Groq, + ChatModel: "llama-3.3-70b-versatile", + TextModel: "", // Groq doesn't support text completion + EmbeddingModel: "", // Groq doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/mistral_test.go b/tests/core-providers/mistral_test.go new file mode 100644 index 000000000..e9347de09 --- /dev/null +++ b/tests/core-providers/mistral_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestMistral(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Mistral, + ChatModel: "pixtral-12b-latest", + TextModel: "", // Mistral doesn't support text completion in newer models + EmbeddingModel: "mistral-embed", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/ollama_test.go b/tests/core-providers/ollama_test.go new file mode 100644 index 000000000..af7960c20 --- /dev/null +++ b/tests/core-providers/ollama_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOllama(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Ollama, + ChatModel: "llama3.2", + TextModel: "", // Ollama doesn't support text completion in newer models + EmbeddingModel: "", // Ollama doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/openai_test.go b/tests/core-providers/openai_test.go new file mode 100644 index 000000000..69142fe41 --- /dev/null +++ b/tests/core-providers/openai_test.go @@ -0,0 +1,52 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenAI(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.OpenAI, + ChatModel: "gpt-4o-mini", + TextModel: "", // OpenAI doesn't support text completion in newer models + EmbeddingModel: "text-embedding-3-small", + TranscriptionModel: "whisper-1", + SpeechSynthesisModel: "tts-1", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/openrouter_test.go b/tests/core-providers/openrouter_test.go new file mode 100644 index 000000000..8ce5ca342 --- /dev/null +++ b/tests/core-providers/openrouter_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "os" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenRouter(t *testing.T) { + if os.Getenv("OPENROUTER_API_KEY") == "" { + t.Skip("OPENROUTER_API_KEY not set; skipping OpenRouter tests") + } + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.OpenRouter, + ChatModel: "openai/gpt-4o", + TextModel: "google/gemini-2.5-flash", + EmbeddingModel: "", + Scenarios: config.TestScenarios{ + TextCompletion: true, + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/parasail_test.go b/tests/core-providers/parasail_test.go new file mode 100644 index 000000000..46e4053e8 --- /dev/null +++ b/tests/core-providers/parasail_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestParasail(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Parasail, + ChatModel: "parasail-deepseek-r1", + TextModel: "", // Parasail doesn't support text completion + EmbeddingModel: "", // Parasail doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, // Not supported yet + ImageBase64: false, // Not supported yet + MultipleImages: false, // Not supported yet + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, // Not supported yet + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/scenarios/automatic_function_calling.go b/tests/core-providers/scenarios/automatic_function_calling.go new file mode 100644 index 000000000..08d9b4191 --- /dev/null +++ b/tests/core-providers/scenarios/automatic_function_calling.go @@ -0,0 +1,76 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// RunAutomaticFunctionCallingTest executes the automatic function calling test scenario +func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.AutomaticFunctionCall { + t.Logf("Automatic function calling not supported for provider %s", testConfig.Provider) + return + } + + t.Run("AutomaticFunctionCalling", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Get the current time in UTC timezone"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{TimeToolDefinition}, + ToolChoice: &schemas.ToolChoice{ + ToolChoiceStruct: &schemas.ToolChoiceStruct{ + Type: schemas.ToolChoiceTypeFunction, + Function: schemas.ToolChoiceFunction{ + Name: "get_current_time", + }, + }, + }, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Automatic function calling failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with valid tool calls + foundValidToolCall := false + for i, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Iterate through all tool calls, not just the first one + for j, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_current_time" { + foundValidToolCall = true + t.Logf("βœ… Automatic function call for choice %d, tool call %d: %s", i, j, toolCall.Function.Arguments) + break // Found valid tool call, can break from this inner loop + } + } + if foundValidToolCall { + break // Found valid tool call, can break from choices loop + } + } + } + + require.True(t, foundValidToolCall, "Expected at least one choice to have automatic tool call for 'get_current_time'. Response: %s", GetResultContent(response)) + }) +} diff --git a/tests/core-providers/scenarios/chat_completion_stream.go b/tests/core-providers/scenarios/chat_completion_stream.go new file mode 100644 index 000000000..6af2e6c3b --- /dev/null +++ b/tests/core-providers/scenarios/chat_completion_stream.go @@ -0,0 +1,231 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunChatCompletionStreamTest executes the chat completion stream test scenario +func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ChatCompletionStream { + t.Logf("Chat completion stream not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ChatCompletionStream", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Tell me a short story about a robot learning to paint. Keep it under 200 words."), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(250), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + // Test streaming response + responseChannel, err := client.ChatCompletionStreamRequest(ctx, request) + require.Nilf(t, err, "Chat completion stream failed: %v", err) + require.NotNil(t, responseChannel, "Response channel should not be nil") + + var fullContent strings.Builder + var responseCount int + var lastResponse *schemas.BifrostStream + + // Create a timeout context for the stream reading + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + t.Logf("πŸ“‘ Starting to read streaming response...") + + // Read streaming responses + for { + select { + case response, ok := <-responseChannel: + if !ok { + // Channel closed, streaming completed + t.Logf("βœ… Streaming completed. Total chunks received: %d", responseCount) + goto streamComplete + } + + require.NotNil(t, response, "Streaming response should not be nil") + lastResponse = response + + // Validate response structure + assert.Equal(t, testConfig.Provider, response.ExtraFields.Provider, "Provider should match") + assert.NotEmpty(t, response.ID, "Response ID should not be empty") + assert.Equal(t, "chat.completion.chunk", response.Object, "Object type should be chat.completion.chunk") + assert.NotEmpty(t, response.Choices, "Choices should not be empty") + + // Process each choice in the response + for _, choice := range response.Choices { + // Validate that this is a stream response + assert.NotNil(t, choice.BifrostStreamResponseChoice, "Stream response choice should not be nil") + assert.Nil(t, choice.BifrostNonStreamResponseChoice, "Non-stream response choice should be nil") + + // Get content from delta + if choice.BifrostStreamResponseChoice != nil { + delta := choice.BifrostStreamResponseChoice.Delta + if delta.Content != nil { + fullContent.WriteString(*delta.Content) + } + + // Log role if present (usually in first chunk) + if delta.Role != nil { + t.Logf("πŸ€– Role: %s", *delta.Role) + } + + // Check finish reason if present + if choice.FinishReason != nil { + t.Logf("🏁 Finish reason: %s", *choice.FinishReason) + } + } + } + + responseCount++ + + // Safety check to prevent infinite loops in case of issues + if responseCount > 500 { + t.Fatal("Received too many streaming chunks, something might be wrong") + } + + case <-streamCtx.Done(): + t.Fatal("Timeout waiting for streaming response") + } + } + + streamComplete: + // Validate that the last response contains usage information and/or finish reason + // with empty choices (typical final chunk pattern) + if lastResponse != nil && lastResponse.BifrostResponse != nil { + // Check if this is a final metadata chunk (empty choices with usage/finish info) + if len(lastResponse.Choices) == 0 && lastResponse.Usage != nil { + assert.Greater(t, lastResponse.Usage.TotalTokens, 0, "Final chunk should have total token count") + t.Logf("πŸ“Š Final metadata chunk - Total tokens: %d", lastResponse.Usage.TotalTokens) + } else if len(lastResponse.Choices) > 0 { + // Check if final choice has finish reason + finalChoice := lastResponse.Choices[0] + if finalChoice.FinishReason != nil { + t.Logf("🏁 Stream ended with finish reason: %s", *finalChoice.FinishReason) + } + } else { + t.Fatal("Last response should have choices or usage") + } + } + + // Validate the complete response + assert.Greater(t, responseCount, 0, "Should receive at least one streaming response") + + finalContent := strings.TrimSpace(fullContent.String()) + assert.NotEmpty(t, finalContent, "Final content should not be empty") + assert.Greater(t, len(finalContent), 10, "Final content should be substantial") + + if lastResponse.BifrostResponse != nil { + // Validate the last response has usage information + if len(lastResponse.Choices) > 0 { + finishReason := lastResponse.Choices[0].FinishReason + assert.NotNil(t, finishReason, "Finish reason should not be nil") + } else { + // This is a metadata-only chunk, which is valid for final chunks + assert.NotNil(t, lastResponse.Usage, "Usage should not be nil") + } + } + + t.Logf("βœ… Streaming test completed successfully") + t.Logf("πŸ“ Final content (%d chars)", len(finalContent)) + }) + + // Test streaming with tool calls if supported + if testConfig.Scenarios.ToolCalls { + t.Run("ChatCompletionStreamWithTools", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("What's the weather like in San Francisco? Please use the get_weather function."), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + Tools: &[]schemas.Tool{WeatherToolDefinition}, + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.ChatCompletionStreamRequest(ctx, request) + require.Nilf(t, err, "Chat completion stream with tools failed: %v", err) + require.NotNil(t, responseChannel, "Response channel should not be nil") + + var toolCallDetected bool + var responseCount int + + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + t.Logf("πŸ”§ Testing streaming with tool calls...") + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto toolStreamComplete + } + + require.NotNil(t, response, "Streaming response should not be nil") + responseCount++ + + for _, choice := range response.Choices { + if choice.BifrostStreamResponseChoice != nil { + delta := choice.BifrostStreamResponseChoice.Delta + + // Check for tool calls in delta + if len(delta.ToolCalls) > 0 { + toolCallDetected = true + t.Logf("πŸ”§ Tool call detected in streaming response") + + for _, toolCall := range delta.ToolCalls { + if toolCall.Function.Name != nil { + t.Logf("πŸ”§ Tool: %s", *toolCall.Function.Name) + if toolCall.Function.Arguments != "" { + t.Logf("πŸ”§ Args: %s", toolCall.Function.Arguments) + } + } + } + } + } + } + + if responseCount > 100 { + goto toolStreamComplete + } + + case <-streamCtx.Done(): + t.Fatal("Timeout waiting for streaming response with tools") + } + } + + toolStreamComplete: + assert.Greater(t, responseCount, 0, "Should receive at least one streaming response") + assert.True(t, toolCallDetected, "Should detect tool calls in streaming response") + t.Logf("βœ… Streaming with tools test completed successfully") + }) + } +} diff --git a/tests/core-providers/scenarios/complete_end_to_end.go b/tests/core-providers/scenarios/complete_end_to_end.go new file mode 100644 index 000000000..880485a06 --- /dev/null +++ b/tests/core-providers/scenarios/complete_end_to_end.go @@ -0,0 +1,118 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunCompleteEnd2EndTest executes the complete end-to-end test scenario +func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.CompleteEnd2End { + t.Logf("Complete end-to-end not supported for provider %s", testConfig.Provider) + return + } + + t.Run("CompleteEnd2End", func(t *testing.T) { + // Multi-step conversation with tools and images + userMessage1 := CreateBasicChatMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?") + + request1 := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{userMessage1}, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response1, err := client.ChatCompletionRequest(ctx, request1) + require.Nilf(t, err, "First end-to-end request failed: %v", err) + require.NotNil(t, response1) + require.NotEmpty(t, response1.Choices) + + t.Logf("βœ… First response: %s", GetResultContent(response1)) + + // If tool was called, simulate result and continue conversation + var conversationHistory []schemas.BifrostMessage + conversationHistory = append(conversationHistory, userMessage1) + + // Add all choice messages to conversation history + for _, choice := range response1.Choices { + conversationHistory = append(conversationHistory, choice.Message) + } + + // Find any choice with tool calls for processing + var selectedToolCall *schemas.ToolCall + for _, choice := range response1.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Look for a valid weather tool call + for _, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_weather" { + selectedToolCall = &toolCall + break + } + } + if selectedToolCall != nil { + break + } + } + } + + // If a tool call was found, simulate the result + if selectedToolCall != nil { + // Simulate tool result + toolResult := `{"temperature": "18", "unit": "celsius", "description": "Partly cloudy", "humidity": "70%"}` + toolCallID := "" + if selectedToolCall.ID != nil { + toolCallID = *selectedToolCall.ID + } else if selectedToolCall.Function.Name != nil { + toolCallID = *selectedToolCall.Function.Name + } + require.NotEmpty(t, toolCallID, "toolCallID must not be empty – provider did not return ID or Function.Name") + toolMessage := CreateToolMessage(toolResult, toolCallID) + conversationHistory = append(conversationHistory, toolMessage) + } + + // Continue with follow-up + followUpMessage := CreateBasicChatMessage("Thanks! Now can you tell me about this travel image?") + if testConfig.Scenarios.ImageURL { + followUpMessage = CreateImageMessage("Thanks! Now can you tell me what you see in this travel-related image?", TestImageURL) + } + conversationHistory = append(conversationHistory, followUpMessage) + + finalRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationHistory, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + finalResponse, err := client.ChatCompletionRequest(ctx, finalRequest) + require.Nilf(t, err, "Final end-to-end request failed: %v", err) + require.NotNil(t, finalResponse) + require.NotEmpty(t, finalResponse.Choices) + + finalContent := GetResultContent(finalResponse) + assert.NotEmpty(t, finalContent, "Final response content should not be empty") + + t.Logf("βœ… Complete end-to-end result: %s", finalContent) + }) +} diff --git a/tests/core-providers/scenarios/embedding.go b/tests/core-providers/scenarios/embedding.go new file mode 100644 index 000000000..c15749236 --- /dev/null +++ b/tests/core-providers/scenarios/embedding.go @@ -0,0 +1,120 @@ +package scenarios + +import ( + "context" + "fmt" + "math" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// cosineSimilarity computes the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float64 { + if len(a) != len(b) { + panic(fmt.Errorf("cosineSimilarity: vectors must have same length, got %d and %d", len(a), len(b))) + } + + var dotProduct float64 + var normA float64 + var normB float64 + + for i := 0; i < len(a); i++ { + dotProduct += float64(a[i] * b[i]) + normA += float64(a[i] * a[i]) + normB += float64(b[i] * b[i]) + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) +} + +// getEmbeddingVector extracts the embedding vector from BifrostEmbeddingResponse +func getEmbeddingVector(embedding schemas.BifrostEmbeddingResponse) ([]float32, error) { + if embedding.EmbeddingArray != nil { + return *embedding.EmbeddingArray, nil + } + if embedding.Embedding2DArray != nil && len(*embedding.Embedding2DArray) > 0 { + return (*embedding.Embedding2DArray)[0], nil + } + return nil, fmt.Errorf("no valid embedding vector found") +} + +// RunEmbeddingTest executes the embedding test scenario +func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.Embedding { + t.Logf("Embedding not supported for provider %s", testConfig.Provider) + return + } + + if strings.TrimSpace(testConfig.EmbeddingModel) == "" { + t.Skipf("Embedding enabled but model is not configured for provider %s; skipping", testConfig.Provider) + } + + t.Run(fmt.Sprintf("Embedding/%s/%s", testConfig.Provider, testConfig.EmbeddingModel), func(t *testing.T) { + // Test texts with expected semantic relationships + testTexts := []string{ + "Hello, world!", + "Hi, world!", + "Goodnight, moon!", + } + + // Get embeddings for all test texts + embeddings := make([][]float32, len(testTexts)) + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.EmbeddingModel, + Input: schemas.RequestInput{ + EmbeddingInput: &schemas.EmbeddingInput{ + Texts: testTexts, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + EncodingFormat: bifrost.Ptr("float"), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.EmbeddingRequest(ctx, request) + require.Nilf(t, err, "Embedding request failed: %v", err) + require.NotNil(t, response) + require.Lenf(t, response.Data, len(testTexts), "expected %d results", len(testTexts)) + for i := range response.Data { + vec, extractErr := getEmbeddingVector(response.Data[i].Embedding) + require.NoErrorf(t, extractErr, "Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr) + require.NotEmptyf(t, vec, "Embedding vector is empty for text '%s'", testTexts[i]) + embeddings[i] = vec + } + + // Ensure all embeddings have the same length + embeddingLength := len(embeddings[0]) + require.Greaterf(t, embeddingLength, 0, "First embedding length must be > 0") + for i, embedding := range embeddings { + require.Equalf(t, embeddingLength, len(embedding), + "Embedding %d has different length (%d) than first embedding (%d)", + i, len(embedding), embeddingLength) + } + + // Compute pairwise similarities + similarityHelloHi := cosineSimilarity(embeddings[0], embeddings[1]) // "Hello, world!" vs "Hi, world!" + similarityHelloGoodnight := cosineSimilarity(embeddings[0], embeddings[2]) // "Hello, world!" vs "Goodnight, moon!" + + // Assert semantic coherence: similar phrases should be more similar than dissimilar ones + require.Greaterf(t, similarityHelloHi, similarityHelloGoodnight+0.02, + "Semantic coherence test failed: similarity('Hello, world!' vs 'Hi, world!') = %.6f should be greater than similarity('Hello, world!' vs 'Goodnight, moon!') = %.6f. This suggests the embedding model may not be capturing semantic meaning correctly.", + similarityHelloHi, similarityHelloGoodnight) + + t.Logf("βœ… Semantic coherence validated:") + t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi) + t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight) + t.Logf(" Difference: %.6f", similarityHelloHi-similarityHelloGoodnight) + }) +} diff --git a/tests/core-providers/scenarios/end_to_end_tool_calling.go b/tests/core-providers/scenarios/end_to_end_tool_calling.go new file mode 100644 index 000000000..9995b61cc --- /dev/null +++ b/tests/core-providers/scenarios/end_to_end_tool_calling.go @@ -0,0 +1,121 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunEnd2EndToolCallingTest executes the end-to-end tool calling test scenario +func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.End2EndToolCalling { + t.Logf("End-to-end tool calling not supported for provider %s", testConfig.Provider) + return + } + + t.Run("End2EndToolCalling", func(t *testing.T) { + // Step 1: User asks for weather + userMessage := CreateBasicChatMessage("What's the weather in San Francisco?") + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{userMessage}, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + // Execute first request + firstResponse, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "First request failed: %v", err) + require.NotNil(t, firstResponse) + require.NotEmpty(t, firstResponse.Choices) + + // Find a choice with valid tool calls + var toolCall schemas.ToolCall + foundValidChoice := false + + for _, choice := range firstResponse.Choices { + if choice.Message.AssistantMessage != nil && + choice.Message.AssistantMessage.ToolCalls != nil && + len(*choice.Message.AssistantMessage.ToolCalls) > 0 { + + firstToolCall := (*choice.Message.AssistantMessage.ToolCalls)[0] + if firstToolCall.Function.Name != nil && *firstToolCall.Function.Name == "get_weather" { + toolCall = firstToolCall + foundValidChoice = true + break + } + } + } + + require.True(t, foundValidChoice, "Expected at least one choice to have valid tool call for 'get_weather'") + + // Step 2: Simulate tool execution and provide result + toolResult := `{"temperature": "22", "unit": "celsius", "description": "Sunny with light clouds", "humidity": "65%"}` + + toolCallID := "" + if toolCall.ID != nil { + toolCallID = *toolCall.ID + } else { + toolCallID = *toolCall.Function.Name + } + + require.NotEmpty(t, toolCallID, "toolCallID must not be empty") + + // Build conversation history with all choice messages from first response + conversationMessages := []schemas.BifrostMessage{ + userMessage, + } + + // Add all choice messages from the first response + for _, choice := range firstResponse.Choices { + conversationMessages = append(conversationMessages, choice.Message) + } + + // Add the tool result message + conversationMessages = append(conversationMessages, CreateToolMessage(toolResult, toolCallID)) + + secondRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationMessages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + // Execute second request + finalResponse, err := client.ChatCompletionRequest(ctx, secondRequest) + require.Nilf(t, err, "Second request failed: %v", err) + require.NotNil(t, finalResponse) + require.NotEmpty(t, finalResponse.Choices) + + content := GetResultContent(finalResponse) + require.NotEmpty(t, content, "Response content should not be empty") + + // Verify response contains expected information + assert.Contains(t, strings.ToLower(content), "san francisco", "Response should mention San Francisco") + assert.Contains(t, content, "22", "Response should mention temperature") + assert.Contains(t, strings.ToLower(content), "sunny", "Response should mention weather description") + + t.Logf("βœ… End-to-end tool calling result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/image_base64.go b/tests/core-providers/scenarios/image_base64.go new file mode 100644 index 000000000..b10655d6f --- /dev/null +++ b/tests/core-providers/scenarios/image_base64.go @@ -0,0 +1,49 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunImageBase64Test executes the image base64 test scenario +func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ImageBase64 { + t.Logf("Image base64 not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ImageBase64", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateImageMessage("Describe this image briefly", TestImageBase64), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Image base64 test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Image base64 result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/image_url.go b/tests/core-providers/scenarios/image_url.go new file mode 100644 index 000000000..9f11d6328 --- /dev/null +++ b/tests/core-providers/scenarios/image_url.go @@ -0,0 +1,55 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunImageURLTest executes the image URL test scenario +func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ImageURL { + t.Logf("Image URL not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ImageURL", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateImageMessage("What do you see in this image?", TestImageURL), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Image URL test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + // Should mention something about the ant in the image + lowerContent := strings.ToLower(content) + assert.True(t, strings.Contains(lowerContent, "ant") || + strings.Contains(lowerContent, "insect"), + "Response should identify the ant/insect in the image") + + t.Logf("βœ… Image URL result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multi_turn_conversation.go b/tests/core-providers/scenarios/multi_turn_conversation.go new file mode 100644 index 000000000..10512578b --- /dev/null +++ b/tests/core-providers/scenarios/multi_turn_conversation.go @@ -0,0 +1,84 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunMultiTurnConversationTest executes the multi-turn conversation test scenario +func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultiTurnConversation { + t.Logf("Multi-turn conversation not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultiTurnConversation", func(t *testing.T) { + // First message + userMessage1 := CreateBasicChatMessage("My name is Alice. Remember this.") + messages1 := []schemas.BifrostMessage{ + userMessage1, + } + + firstRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages1, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response1, err := client.ChatCompletionRequest(ctx, firstRequest) + require.Nilf(t, err, "First conversation turn failed: %v", err) + require.NotNil(t, response1) + require.NotEmpty(t, response1.Choices) + + // Second message with conversation history + // Build conversation history with all choice messages + messages2 := []schemas.BifrostMessage{ + userMessage1, + } + + // Add all choice messages from the first response + for _, choice := range response1.Choices { + messages2 = append(messages2, choice.Message) + } + + // Add the follow-up question + messages2 = append(messages2, CreateBasicChatMessage("What's my name?")) + + secondRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages2, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response2, err := client.ChatCompletionRequest(ctx, secondRequest) + require.Nilf(t, err, "Second conversation turn failed: %v", err) + require.NotNil(t, response2) + require.NotEmpty(t, response2.Choices) + + content := GetResultContent(response2) + assert.NotEmpty(t, content, "Response content should not be empty") + // Check if the model remembered the name + assert.Contains(t, strings.ToLower(content), "alice", "Model should remember the name Alice") + t.Logf("βœ… Multi-turn conversation result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multiple_images.go b/tests/core-providers/scenarios/multiple_images.go new file mode 100644 index 000000000..ba8d70a2b --- /dev/null +++ b/tests/core-providers/scenarios/multiple_images.go @@ -0,0 +1,71 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunMultipleImagesTest executes the multiple images test scenario +func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultipleImages { + t.Logf("Multiple images not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultipleImages", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("Compare these two images - what are the similarities and differences?"), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: TestImageURL, + }, + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: TestImageBase64, + }, + }, + }, + }, + }, + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(300), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Multiple images test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Multiple images result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multiple_tool_calls.go b/tests/core-providers/scenarios/multiple_tool_calls.go new file mode 100644 index 000000000..dd65e663c --- /dev/null +++ b/tests/core-providers/scenarios/multiple_tool_calls.go @@ -0,0 +1,109 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "slices" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// getKeysFromMap returns the keys of a map[string]bool as a slice +func getKeysFromMap(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// RunMultipleToolCallsTest executes the multiple tool calls test scenario +func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultipleToolCalls { + t.Logf("Multiple tool calls not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultipleToolCalls", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both?"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition, CalculatorToolDefinition}, + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Multiple tool calls failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with multiple valid tool calls + expectedToolNames := []string{"get_weather", "calculate"} + foundValidMultipleToolCalls := false + for choiceIdx, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + if len(toolCalls) >= 2 { + validToolCalls := 0 + foundToolNames := make(map[string]bool) + + for _, toolCall := range toolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + // Check if this is one of the expected tool names + isExpected := false + for _, expectedName := range expectedToolNames { + if toolName == expectedName { + isExpected = true + foundToolNames[toolName] = true + break + } + } + if isExpected { + validToolCalls++ + } + } + } + + // Require at least 2 valid tool calls with expected names + if validToolCalls >= 2 { + foundValidMultipleToolCalls = true + t.Logf("βœ… Number of tool calls for choice %d: %d", choiceIdx, len(toolCalls)) + t.Logf("βœ… Found expected tools: %v", getKeysFromMap(foundToolNames)) + + for i, toolCall := range toolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + // Validate that each tool name is expected + isExpected := slices.Contains(expectedToolNames, toolName) + require.True(t, isExpected, "Unexpected tool call '%s' - expected one of %v", toolName, expectedToolNames) + t.Logf("βœ… Tool call %d for choice %d: %s with args: %s", i+1, choiceIdx, toolName, toolCall.Function.Arguments) + } + } + break // Found a valid choice with multiple tool calls + } + } + } + } + + require.True(t, foundValidMultipleToolCalls, "Expected at least one choice to have 2 or more valid tool calls. Response: %s", GetResultContent(response)) + }) +} diff --git a/tests/core-providers/scenarios/provider_specific.go b/tests/core-providers/scenarios/provider_specific.go new file mode 100644 index 000000000..0c4b3bff0 --- /dev/null +++ b/tests/core-providers/scenarios/provider_specific.go @@ -0,0 +1,55 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunProviderSpecificTest executes the provider-specific test scenario +func RunProviderSpecificTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ProviderSpecific { + t.Logf("Provider-specific tests not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ProviderSpecific", func(t *testing.T) { + // This would contain provider-specific tests + // For now, we'll do a basic functionality test + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Test provider-specific functionality. What makes you unique?"), + } + + // Initialize with default parameters and merge with custom parameters + defaultParams := &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + } + params := MergeModelParameters(defaultParams, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Provider-specific test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Provider-specific result for %s: %s", testConfig.Provider, content) + }) +} diff --git a/tests/core-providers/scenarios/simple_chat.go b/tests/core-providers/scenarios/simple_chat.go new file mode 100644 index 000000000..5665e4fc6 --- /dev/null +++ b/tests/core-providers/scenarios/simple_chat.go @@ -0,0 +1,48 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunSimpleChatTest executes the simple chat test scenario +func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SimpleChat { + t.Logf("Simple chat not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SimpleChat", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Hello! What's the capital of France?"), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Simple chat failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + t.Logf("βœ… Simple chat result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/speech_synthesis.go b/tests/core-providers/scenarios/speech_synthesis.go new file mode 100644 index 000000000..6691c4ccc --- /dev/null +++ b/tests/core-providers/scenarios/speech_synthesis.go @@ -0,0 +1,197 @@ +package scenarios + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunSpeechSynthesisTest executes the speech synthesis test scenario +func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SpeechSynthesis { + t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SpeechSynthesis", func(t *testing.T) { + // Test with shared text constants for round-trip validation with transcription + testCases := []struct { + name string + text string + voiceType string + format string + expectMinBytes int + saveForSST bool // Whether to save this audio for SST round-trip testing + }{ + { + name: "BasicText_Primary_MP3", + text: TTSTestTextBasic, + voiceType: "primary", + format: "mp3", + expectMinBytes: 1000, + saveForSST: true, + }, + { + name: "MediumText_Secondary_MP3", + text: TTSTestTextMedium, + voiceType: "secondary", + format: "mp3", + expectMinBytes: 2000, + saveForSST: true, + }, + { + name: "TechnicalText_Tertiary_MP3", + text: TTSTestTextTechnical, + voiceType: "tertiary", + format: "mp3", + expectMinBytes: 500, + saveForSST: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + voice := GetProviderVoice(testConfig.Provider, tc.voiceType) + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.SpeechSynthesisModel, // Use configured model + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: tc.text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: tc.format, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.SpeechRequest(ctx, request) + require.Nilf(t, err, "Speech synthesis failed: %v", err) + require.NotNil(t, response) + require.NotNil(t, response.Speech) + require.NotNil(t, response.Speech.Audio) + + // Validate audio data + assert.Greater(t, len(response.Speech.Audio), tc.expectMinBytes, "Audio data should have minimum expected size") + assert.Equal(t, "audio.speech", response.Object) + assert.Equal(t, testConfig.SpeechSynthesisModel, response.Model) + + // Save audio file for SST round-trip testing if requested + if tc.saveForSST { + tempDir := os.TempDir() + audioFileName := filepath.Join(tempDir, "tts_"+tc.name+"."+tc.format) + + err := os.WriteFile(audioFileName, response.Speech.Audio, 0644) + require.NoError(t, err, "Failed to save audio file for SST testing") + + // Register cleanup to remove temp file + t.Cleanup(func() { + os.Remove(audioFileName) + }) + + t.Logf("πŸ’Ύ Audio saved for SST testing: %s (text: '%s')", audioFileName, tc.text) + } + + t.Logf("βœ… Speech synthesis successful: %d bytes of %s audio generated for voice '%s'", + len(response.Speech.Audio), tc.format, voice) + }) + } + }) +} + +// RunSpeechSynthesisAdvancedTest executes advanced speech synthesis test scenarios +func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SpeechSynthesis { + t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SpeechSynthesisAdvanced", func(t *testing.T) { + t.Run("LongText_HDModel", func(t *testing.T) { + // Test with longer text and HD model + longText := ` + This is a comprehensive test of the text-to-speech functionality using a longer piece of text. + The system should be able to handle multiple sentences, proper punctuation, and maintain + consistent voice quality throughout the entire speech generation process. This test ensures + that the speech synthesis can handle realistic use cases with substantial content. + ` + + voice := "shimmer" + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: "tts-1-hd", // Test with HD model + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: longText, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: "mp3", + Instructions: "Speak slowly and clearly with natural intonation.", + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.SpeechRequest(ctx, request) + require.Nilf(t, err, "HD speech synthesis failed: %v", err) + require.NotNil(t, response) + require.NotNil(t, response.Speech) + require.NotNil(t, response.Speech.Audio) + + // Validate longer audio + assert.Greater(t, len(response.Speech.Audio), 5000, "HD audio should be substantial") + assert.Equal(t, "tts-1-hd", response.Model) + + t.Logf("βœ… HD speech synthesis successful: %d bytes generated", len(response.Speech.Audio)) + }) + + t.Run("AllVoiceOptions", func(t *testing.T) { + // Test provider-specific voice options + voiceTypes := []string{"primary", "secondary", "tertiary"} + testText := TTSTestTextBasic // Use shared constant + + for _, voiceType := range voiceTypes { + t.Run("VoiceType_"+voiceType, func(t *testing.T) { + voice := GetProviderVoice(testConfig.Provider, voiceType) + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.SpeechSynthesisModel, + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: testText, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: "mp3", + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.SpeechRequest(ctx, request) + require.Nilf(t, err, "Speech synthesis failed for voice %s (%s): %v", voice, voiceType, err) + require.NotNil(t, response) + require.NotNil(t, response.Speech) + require.NotNil(t, response.Speech.Audio) + + assert.Greater(t, len(response.Speech.Audio), 500, "Audio should be generated for voice %s", voice) + t.Logf("βœ… Voice %s (%s): %d bytes generated", voice, voiceType, len(response.Speech.Audio)) + }) + } + }) + }) +} diff --git a/tests/core-providers/scenarios/speech_synthesis_stream.go b/tests/core-providers/scenarios/speech_synthesis_stream.go new file mode 100644 index 000000000..7c45fd9d9 --- /dev/null +++ b/tests/core-providers/scenarios/speech_synthesis_stream.go @@ -0,0 +1,274 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunSpeechSynthesisStreamTest executes the streaming speech synthesis test scenario +func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SpeechSynthesisStream { + t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SpeechSynthesisStream", func(t *testing.T) { + // Test streaming with different text lengths + testCases := []struct { + name string + text string + voice string + format string + expectMinChunks int + expectMinBytes int + }{ + { + name: "ShortText_Streaming", + text: "This is a short text for streaming speech synthesis test.", + voice: "alloy", + format: "mp3", + expectMinChunks: 1, + expectMinBytes: 1000, + }, + { + name: "LongText_Streaming", + text: `This is a longer text to test streaming speech synthesis functionality. + The streaming should provide audio chunks as they are generated, allowing for + real-time playback while the rest of the audio is still being processed. + This enables better user experience with reduced latency.`, + voice: "nova", + format: "mp3", + expectMinChunks: 2, + expectMinBytes: 3000, + }, + { + name: "MediumText_Echo_WAV", + text: "Testing streaming with WAV format. This should produce multiple audio chunks in WAV format for streaming playback.", + voice: "echo", + format: "wav", + expectMinChunks: 1, + expectMinBytes: 2000, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + voice := tc.voice + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: "gpt-4o-mini-tts", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: tc.text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: tc.format, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + // Test streaming response + responseChannel, err := client.SpeechStreamRequest(ctx, request) + require.Nilf(t, err, "Speech synthesis stream failed: %v", err) + require.NotNil(t, responseChannel, "Response channel should not be nil") + + var totalAudioBytes []byte + var chunkCount int + var lastResponse *schemas.BifrostStream + + // Create a timeout context for the stream reading + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Read streaming chunks + for { + select { + case response, ok := <-responseChannel: + if !ok { + // Channel closed, streaming complete + goto streamComplete + } + + require.NotNil(t, response, "Stream response should not be nil") + + // Check for errors in stream + if response.BifrostError != nil { + t.Fatalf("Error in stream: %v", response.BifrostError) + } + + require.NotNil(t, response.Speech, "Speech data should be present in stream") + + // Collect audio chunks + if response.Speech.Audio != nil { + totalAudioBytes = append(totalAudioBytes, response.Speech.Audio...) + chunkCount++ + t.Logf("Received audio chunk %d: %d bytes", chunkCount, len(response.Speech.Audio)) + } + + // Validate stream response structure + assert.Equal(t, "audio.speech.chunk", response.Object) + assert.Equal(t, "gpt-4o-mini-tts", response.Model) + assert.Equal(t, testConfig.Provider, response.ExtraFields.Provider) + + lastResponse = response + + case <-streamCtx.Done(): + t.Fatal("Stream reading timed out") + } + } + + streamComplete: + // Validate streaming results + assert.GreaterOrEqual(t, chunkCount, tc.expectMinChunks, "Should receive minimum expected chunks") + assert.Greater(t, len(totalAudioBytes), tc.expectMinBytes, "Total audio should meet minimum size") + assert.NotNil(t, lastResponse, "Should have received at least one response") + + t.Logf("βœ… Streaming speech synthesis successful: %d chunks, %d total bytes for voice '%s' in %s format", + chunkCount, len(totalAudioBytes), tc.voice, tc.format) + }) + } + }) +} + +// RunSpeechSynthesisStreamAdvancedTest executes advanced streaming speech synthesis test scenarios +func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SpeechSynthesisStream { + t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SpeechSynthesisStreamAdvanced", func(t *testing.T) { + t.Run("LongText_HDModel_Streaming", func(t *testing.T) { + // Test streaming with HD model and very long text + finalText := "" + for i := 1; i <= 20; i++ { + finalText += strings.Replace("This is sentence number %d in a very long text for testing streaming speech synthesis with the HD model. ", "%d", string(rune('0'+i%10)), -1) + } + + voice := "shimmer" + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: "gpt-4o-mini-tts", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: finalText, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: "mp3", + Instructions: "Speak at a natural pace with clear pronunciation.", + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.SpeechStreamRequest(ctx, request) + require.Nilf(t, err, "HD streaming speech synthesis failed: %v", err) + + var totalBytes int + var chunkCount int + streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second) // Longer timeout for HD model + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto hdStreamComplete + } + + if response.BifrostError != nil { + t.Fatalf("Error in HD stream: %v", response.BifrostError) + } + + if response.Speech != nil && response.Speech.Audio != nil { + totalBytes += len(response.Speech.Audio) + chunkCount++ + } + + assert.Equal(t, "gpt-4o-mini-tts", response.Model) + + case <-streamCtx.Done(): + t.Fatal("HD stream reading timed out") + } + } + + hdStreamComplete: + assert.Greater(t, chunkCount, 3, "HD model should produce multiple chunks for long text") + assert.Greater(t, totalBytes, 10000, "HD model should produce substantial audio data") + + t.Logf("βœ… HD streaming successful: %d chunks, %d total bytes", chunkCount, totalBytes) + }) + + t.Run("MultipleVoices_Streaming", func(t *testing.T) { + // Test streaming with all available voices + voices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"} + testText := "Testing streaming speech synthesis with different voice options." + + for _, voice := range voices { + t.Run("StreamingVoice_"+voice, func(t *testing.T) { + voiceCopy := voice + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: "gpt-4o-mini-tts", + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: testText, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voiceCopy, + }, + ResponseFormat: "mp3", + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.SpeechStreamRequest(ctx, request) + require.Nilf(t, err, "Streaming failed for voice %s: %v", voice, err) + + var receivedData bool + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto voiceStreamComplete + } + + if response.BifrostError != nil { + t.Fatalf("Error in stream for voice %s: %v", voice, response.BifrostError) + } + + if response.Speech != nil && response.Speech.Audio != nil && len(response.Speech.Audio) > 0 { + receivedData = true + } + + case <-streamCtx.Done(): + t.Fatalf("Stream timed out for voice %s", voice) + } + } + + voiceStreamComplete: + assert.True(t, receivedData, "Should receive audio data for voice %s", voice) + t.Logf("βœ… Streaming successful for voice: %s", voice) + }) + } + }) + }) +} diff --git a/tests/core-providers/scenarios/text_completion.go b/tests/core-providers/scenarios/text_completion.go new file mode 100644 index 000000000..ead8f5fe4 --- /dev/null +++ b/tests/core-providers/scenarios/text_completion.go @@ -0,0 +1,45 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunTextCompletionTest tests text completion functionality +func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.TextCompletion || testConfig.TextModel == "" { + t.Logf("⏭️ Text completion not supported for provider %s", testConfig.Provider) + return + } + + t.Run("TextCompletion", func(t *testing.T) { + prompt := "The future of artificial intelligence is" + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TextModel, + Input: schemas.RequestInput{ + TextCompletionInput: &prompt, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(100), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TextCompletionRequest(ctx, request) + require.Nilf(t, err, "Text completion failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + t.Logf("βœ… Text completion result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/tool_calls.go b/tests/core-providers/scenarios/tool_calls.go new file mode 100644 index 000000000..486b1e094 --- /dev/null +++ b/tests/core-providers/scenarios/tool_calls.go @@ -0,0 +1,79 @@ +package scenarios + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// RunToolCallsTest executes the tool calls test scenario +func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ToolCalls { + t.Logf("Tool calls not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ToolCalls", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("What's the weather like in New York? answer in celsius"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Tool calls failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with valid tool calls + foundValidToolCall := false + for i, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Iterate through all tool calls, not just the first one + for j, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_weather" { + // Verify arguments contain location + var args map[string]interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args) + if err == nil { + if _, hasLocation := args["location"]; hasLocation { + foundValidToolCall = true + t.Logf("βœ… Tool call arguments for choice %d, tool call %d: %s", i, j, toolCall.Function.Arguments) + break // Found valid tool call, can break from this inner loop + } + } + } + } + if foundValidToolCall { + break // Found valid tool call, can break from choices loop + } + } + } + + if !foundValidToolCall { + t.Logf("❌ No valid tool calls found in any choice, response: %s", GetResultContent(response)) + } + require.True(t, foundValidToolCall, "Expected at least one choice to have valid tool call for 'get_weather' with 'location' argument") + }) +} diff --git a/tests/core-providers/scenarios/transcription.go b/tests/core-providers/scenarios/transcription.go new file mode 100644 index 000000000..85520c5ac --- /dev/null +++ b/tests/core-providers/scenarios/transcription.go @@ -0,0 +1,342 @@ +package scenarios + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunTranscriptionTest executes the transcription test scenario +func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.Transcription { + t.Logf("Transcription not supported for provider %s", testConfig.Provider) + return + } + + t.Run("Transcription", func(t *testing.T) { + // First generate TTS audio for round-trip validation + roundTripCases := []struct { + name string + text string + voiceType string + format string + responseFormat *string + }{ + { + name: "RoundTrip_Basic_MP3", + text: TTSTestTextBasic, + voiceType: "primary", + format: "mp3", + responseFormat: bifrost.Ptr("json"), + }, + { + name: "RoundTrip_Medium_MP3", + text: TTSTestTextMedium, + voiceType: "secondary", + format: "mp3", + responseFormat: bifrost.Ptr("json"), + }, + { + name: "RoundTrip_Technical_MP3", + text: TTSTestTextTechnical, + voiceType: "tertiary", + format: "mp3", + responseFormat: bifrost.Ptr("json"), + }, + } + + for _, tc := range roundTripCases { + t.Run(tc.name, func(t *testing.T) { + // Step 1: Generate TTS audio + voice := GetProviderVoice(testConfig.Provider, tc.voiceType) + ttsRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.SpeechSynthesisModel, + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: tc.text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: tc.format, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + ttsResponse, err := client.SpeechRequest(ctx, ttsRequest) + require.Nilf(t, err, "TTS generation failed for round-trip test: %v", err) + require.NotNil(t, ttsResponse.Speech) + require.NotNil(t, ttsResponse.Speech.Audio) + require.Greater(t, len(ttsResponse.Speech.Audio), 0, "TTS returned empty audio") + + // Save temp audio file + tempDir := os.TempDir() + audioFileName := filepath.Join(tempDir, "roundtrip_"+tc.name+"."+tc.format) + writeErr := os.WriteFile(audioFileName, ttsResponse.Speech.Audio, 0644) + require.NoError(t, writeErr, "Failed to save temp audio file") + + // Register cleanup + t.Cleanup(func() { + os.Remove(audioFileName) + }) + + t.Logf("πŸ”„ Generated TTS audio for round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Speech.Audio)) + + // Step 2: Transcribe the generated audio + transcriptionRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: ttsResponse.Speech.Audio, + Language: bifrost.Ptr("en"), + Format: bifrost.Ptr("mp3"), + ResponseFormat: tc.responseFormat, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), // Deterministic + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + transcriptionResponse, err := client.TranscriptionRequest(ctx, transcriptionRequest) + require.Nilf(t, err, "Transcription failed for round-trip test: %v", err) + require.NotNil(t, transcriptionResponse) + require.NotNil(t, transcriptionResponse.Transcribe) + + // Validate round-trip: check if transcribed text contains key words from original + transcribedText := transcriptionResponse.Transcribe.Text + require.NotEmpty(t, transcribedText, "Transcribed text should not be empty") + + // Normalize for comparison (lowercase, remove punctuation) + originalWords := strings.Fields(strings.ToLower(tc.text)) + transcribedWords := strings.Fields(strings.ToLower(transcribedText)) + + // Check that at least 50% of original words are found in transcription + foundWords := 0 + for _, originalWord := range originalWords { + // Remove punctuation for comparison + cleanOriginal := strings.Trim(originalWord, ".,!?;:") + if len(cleanOriginal) < 3 { // Skip very short words + continue + } + + for _, transcribedWord := range transcribedWords { + cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:") + if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) { + foundWords++ + break + } + } + } + + // Expect at least 50% word match for successful round-trip + minExpectedWords := len(originalWords) / 2 + assert.GreaterOrEqual(t, foundWords, minExpectedWords, + "Round-trip failed: original='%s', transcribed='%s', found %d/%d words", + tc.text, transcribedText, foundWords, len(originalWords)) + + // Validate response structure + assert.Equal(t, "audio.transcription", transcriptionResponse.Object) + assert.Equal(t, testConfig.TranscriptionModel, transcriptionResponse.Model) + assert.Equal(t, testConfig.Provider, transcriptionResponse.ExtraFields.Provider) + + // For verbose_json format, check additional fields + if tc.responseFormat != nil && *tc.responseFormat == "verbose_json" { + assert.NotNil(t, transcriptionResponse.Transcribe.BifrostTranscribeNonStreamResponse) + if transcriptionResponse.Transcribe.Task != nil { + assert.Equal(t, "transcribe", *transcriptionResponse.Transcribe.Task) + } + if transcriptionResponse.Transcribe.Language != nil { + assert.NotEmpty(t, *transcriptionResponse.Transcribe.Language) + } + } + + t.Logf("βœ… Round-trip successful: '%s' β†’ TTS β†’ SST β†’ '%s' (found %d/%d words)", + tc.text, transcribedText, foundWords, len(originalWords)) + }) + } + + // Additional test cases using the utility function for edge cases + t.Run("AdditionalAudioTests", func(t *testing.T) { + // Test with custom generated audio for specific scenarios + customCases := []struct { + name string + text string + language *string + responseFormat *string + }{ + { + name: "Numbers_And_Punctuation", + text: "Testing numbers 1, 2, 3 and punctuation marks! Question?", + language: bifrost.Ptr("en"), + responseFormat: bifrost.Ptr("json"), + }, + { + name: "Technical_Terms", + text: "API gateway processes HTTP requests with JSON payloads", + language: bifrost.Ptr("en"), + responseFormat: bifrost.Ptr("json"), + }, + } + + for _, tc := range customCases { + t.Run(tc.name, func(t *testing.T) { + // Use the utility function to generate audio + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, tc.text, "primary", "mp3") + + // Test transcription + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: tc.language, + Format: bifrost.Ptr("mp3"), + ResponseFormat: tc.responseFormat, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TranscriptionRequest(ctx, request) + require.Nilf(t, err, "Custom transcription failed: %v", err) + require.NotNil(t, response.Transcribe) + assert.NotEmpty(t, response.Transcribe.Text) + + t.Logf("βœ… Custom transcription successful: '%s' β†’ '%s'", tc.text, response.Transcribe.Text) + }) + } + }) + }) +} + +// RunTranscriptionAdvancedTest executes advanced transcription test scenarios +func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.Transcription { + t.Logf("Transcription not supported for provider %s", testConfig.Provider) + return + } + + t.Run("TranscriptionAdvanced", func(t *testing.T) { + t.Run("AllResponseFormats", func(t *testing.T) { + // Generate audio first for all format tests + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") + + // Test supported response formats (excluding text to avoid JSON parsing issues) + formats := []string{"json", "verbose_json"} + + for _, format := range formats { + t.Run("Format_"+format, func(t *testing.T) { + formatCopy := format + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Format: bifrost.Ptr("mp3"), + ResponseFormat: &formatCopy, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TranscriptionRequest(ctx, request) + require.Nilf(t, err, "Transcription failed for format %s: %v", format, err) + require.NotNil(t, response) + require.NotNil(t, response.Transcribe) + + // All formats should return some text + assert.NotEmpty(t, response.Transcribe.Text) + + t.Logf("βœ… Format %s successful: '%s'", format, response.Transcribe.Text) + }) + } + }) + + t.Run("WithCustomParameters", func(t *testing.T) { + // Generate audio for custom parameters test + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextMedium, "secondary", "mp3") + + // Test with custom parameters and temperature + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: bifrost.Ptr("en"), + Format: bifrost.Ptr("mp3"), + Prompt: bifrost.Ptr("This audio contains technical terminology and proper nouns."), + ResponseFormat: bifrost.Ptr("json"), // Use json instead of verbose_json for whisper-1 + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.2), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TranscriptionRequest(ctx, request) + require.Nilf(t, err, "Advanced transcription failed: %v", err) + require.NotNil(t, response) + require.NotNil(t, response.Transcribe) + assert.NotEmpty(t, response.Transcribe.Text) + + t.Logf("βœ… Advanced transcription successful: '%s'", response.Transcribe.Text) + }) + + t.Run("MultipleLanguages", func(t *testing.T) { + // Generate audio for language tests + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") + + // Test with different language hints (only English for now since our TTS is English) + languages := []string{"en"} + + for _, lang := range languages { + t.Run("Language_"+lang, func(t *testing.T) { + langCopy := lang + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Format: bifrost.Ptr("mp3"), + Language: &langCopy, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TranscriptionRequest(ctx, request) + require.Nilf(t, err, "Transcription failed for language %s: %v", lang, err) + require.NotNil(t, response) + require.NotNil(t, response.Transcribe) + + assert.NotEmpty(t, response.Transcribe.Text) + t.Logf("βœ… Language %s transcription successful: '%s'", lang, response.Transcribe.Text) + }) + } + }) + }) +} diff --git a/tests/core-providers/scenarios/transcription_stream.go b/tests/core-providers/scenarios/transcription_stream.go new file mode 100644 index 000000000..feb8a9410 --- /dev/null +++ b/tests/core-providers/scenarios/transcription_stream.go @@ -0,0 +1,397 @@ +package scenarios + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunTranscriptionStreamTest executes the streaming transcription test scenario +func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.TranscriptionStream { + t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider) + return + } + + t.Run("TranscriptionStream", func(t *testing.T) { + // Generate TTS audio for streaming round-trip validation + streamRoundTripCases := []struct { + name string + text string + voiceType string + format string + responseFormat *string + expectChunks int + }{ + { + name: "StreamRoundTrip_Basic_MP3", + text: TTSTestTextBasic, + voiceType: "primary", + format: "mp3", + responseFormat: nil, // Default JSON streaming + expectChunks: 1, + }, + { + name: "StreamRoundTrip_Medium_MP3", + text: TTSTestTextMedium, + voiceType: "secondary", + format: "mp3", + responseFormat: bifrost.Ptr("json"), + expectChunks: 1, + }, + { + name: "StreamRoundTrip_Technical_MP3", + text: TTSTestTextTechnical, + voiceType: "tertiary", + format: "mp3", + responseFormat: bifrost.Ptr("json"), + expectChunks: 1, + }, + } + + for _, tc := range streamRoundTripCases { + t.Run(tc.name, func(t *testing.T) { + // Step 1: Generate TTS audio + voice := GetProviderVoice(testConfig.Provider, tc.voiceType) + ttsRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.SpeechSynthesisModel, + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: tc.text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: tc.format, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + ttsResponse, err := client.SpeechRequest(ctx, ttsRequest) + require.Nilf(t, err, "TTS generation failed for stream round-trip test: %v", err) + require.NotNil(t, ttsResponse.Speech) + require.NotNil(t, ttsResponse.Speech.Audio) + require.Greater(t, len(ttsResponse.Speech.Audio), 0, "TTS returned empty audio") + + // Save temp audio file + tempDir := os.TempDir() + audioFileName := filepath.Join(tempDir, "stream_roundtrip_"+tc.name+"."+tc.format) + writeErr := os.WriteFile(audioFileName, ttsResponse.Speech.Audio, 0644) + require.NoError(t, writeErr, "Failed to save temp audio file") + + // Register cleanup + t.Cleanup(func() { + os.Remove(audioFileName) + }) + + t.Logf("πŸ”„ Generated TTS audio for stream round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Speech.Audio)) + + // Step 2: Test streaming transcription + streamRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: ttsResponse.Speech.Audio, + Language: bifrost.Ptr("en"), + Format: bifrost.Ptr(tc.format), + ResponseFormat: tc.responseFormat, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), // More deterministic output + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + // Test streaming response + responseChannel, err := client.TranscriptionStreamRequest(ctx, streamRequest) + require.Nilf(t, err, "Transcription stream failed: %v", err) + require.NotNil(t, responseChannel, "Response channel should not be nil") + + var fullTranscriptionText string + var chunkCount int + var lastResponse *schemas.BifrostStream + + // Create a timeout context for the stream reading + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Read streaming chunks + for { + select { + case response, ok := <-responseChannel: + if !ok { + // Channel closed, streaming complete + goto streamComplete + } + + require.NotNil(t, response, "Stream response should not be nil") + + // Check for errors in stream + if response.BifrostError != nil { + t.Fatalf("Error in stream: %v", response.BifrostError) + } + + require.NotNil(t, response.Transcribe, "Transcribe data should be present in stream") + + // Collect transcription chunks + if response.Transcribe.Text != "" { + if response.Transcribe.BifrostTranscribeStreamResponse != nil && + response.Transcribe.BifrostTranscribeStreamResponse.Delta != nil { + // This is a delta chunk + fullTranscriptionText += *response.Transcribe.BifrostTranscribeStreamResponse.Delta + } else { + // This is a complete text chunk + fullTranscriptionText += response.Transcribe.Text + } + chunkCount++ + t.Logf("Received transcription chunk %d: '%s'", chunkCount, response.Transcribe.Text) + } + + // Validate stream response structure + assert.Equal(t, "audio.transcription.chunk", response.Object) + assert.Equal(t, testConfig.TranscriptionModel, response.Model) + assert.Equal(t, testConfig.Provider, response.ExtraFields.Provider) + + lastResponse = response + + case <-streamCtx.Done(): + t.Fatal("Stream reading timed out") + } + } + + streamComplete: + // Validate streaming results + assert.GreaterOrEqual(t, chunkCount, tc.expectChunks, "Should receive minimum expected chunks") + assert.NotNil(t, lastResponse, "Should have received at least one response") + + // Validate round-trip: check if transcribed text contains key words from original + require.NotEmpty(t, fullTranscriptionText, "Transcribed text should not be empty") + + // Normalize for comparison (lowercase, remove punctuation) + originalWords := strings.Fields(strings.ToLower(tc.text)) + transcribedWords := strings.Fields(strings.ToLower(fullTranscriptionText)) + + // Check that at least 50% of original words are found in transcription + foundWords := 0 + for _, originalWord := range originalWords { + // Remove punctuation for comparison + cleanOriginal := strings.Trim(originalWord, ".,!?;:") + if len(cleanOriginal) < 3 { // Skip very short words + continue + } + + for _, transcribedWord := range transcribedWords { + cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:") + if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) { + foundWords++ + break + } + } + } + + // Expect at least 50% word match for successful round-trip + minExpectedWords := len(originalWords) / 2 + assert.GreaterOrEqual(t, foundWords, minExpectedWords, + "Stream round-trip failed: original='%s', transcribed='%s', found %d/%d words", + tc.text, fullTranscriptionText, foundWords, len(originalWords)) + + t.Logf("βœ… Stream round-trip successful: '%s' β†’ TTS β†’ SST β†’ '%s' (%d chunks, found %d/%d words)", + tc.text, fullTranscriptionText, chunkCount, foundWords, len(originalWords)) + }) + } + }) +} + +// RunTranscriptionStreamAdvancedTest executes advanced streaming transcription test scenarios +func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.TranscriptionStream { + t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider) + return + } + + t.Run("TranscriptionStreamAdvanced", func(t *testing.T) { + t.Run("JSONStreaming", func(t *testing.T) { + // Generate audio for streaming test + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") + + // Test streaming with JSON format + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: bifrost.Ptr("en"), + Format: bifrost.Ptr("mp3"), + ResponseFormat: bifrost.Ptr("json"), + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.TranscriptionStreamRequest(ctx, request) + require.Nilf(t, err, "JSON streaming failed: %v", err) + + var receivedResponse bool + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto verboseStreamComplete + } + + if response.BifrostError != nil { + t.Fatalf("Error in verbose stream: %v", response.BifrostError) + } + + if response.Transcribe != nil { + receivedResponse = true + + // Check for verbose_json specific fields + if response.Transcribe.BifrostTranscribeStreamResponse != nil { + t.Logf("Stream type: %v", response.Transcribe.BifrostTranscribeStreamResponse.Type) + if response.Transcribe.BifrostTranscribeStreamResponse.Delta != nil { + t.Logf("Delta: %s", *response.Transcribe.BifrostTranscribeStreamResponse.Delta) + } + } + } + + case <-streamCtx.Done(): + t.Fatal("Verbose stream reading timed out") + } + } + + verboseStreamComplete: + assert.True(t, receivedResponse, "Should receive at least one response") + t.Logf("βœ… Verbose JSON streaming successful") + }) + + t.Run("MultipleLanguages_Streaming", func(t *testing.T) { + // Generate audio for language streaming tests + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") + + // Test streaming with different language hints (only English for now) + languages := []string{"en"} + + for _, lang := range languages { + t.Run("StreamLang_"+lang, func(t *testing.T) { + langCopy := lang + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: &langCopy, + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{}, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.TranscriptionStreamRequest(ctx, request) + require.Nilf(t, err, "Streaming failed for language %s: %v", lang, err) + + var receivedData bool + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto langStreamComplete + } + + if response.BifrostError != nil { + t.Fatalf("Error in stream for language %s: %v", lang, response.BifrostError) + } + + if response.Transcribe != nil { + receivedData = true + } + + case <-streamCtx.Done(): + t.Fatalf("Stream timed out for language %s", lang) + } + } + + langStreamComplete: + assert.True(t, receivedData, "Should receive transcription data for language %s", lang) + t.Logf("βœ… Streaming successful for language: %s", lang) + }) + } + }) + + t.Run("WithCustomPrompt_Streaming", func(t *testing.T) { + // Generate audio for custom prompt streaming test + audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextTechnical, "tertiary", "mp3") + + // Test streaming with custom prompt for context + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TranscriptionModel, + Input: schemas.RequestInput{ + TranscriptionInput: &schemas.TranscriptionInput{ + File: audioData, + Language: bifrost.Ptr("en"), + Prompt: bifrost.Ptr("This audio contains technical terms, proper nouns, and streaming-related vocabulary."), + }, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.1), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + responseChannel, err := client.TranscriptionStreamRequest(ctx, request) + require.Nilf(t, err, "Custom prompt streaming failed: %v", err) + + var chunkCount int + streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto promptStreamComplete + } + + if response.BifrostError != nil { + t.Fatalf("Error in prompt stream: %v", response.BifrostError) + } + + if response.Transcribe != nil && response.Transcribe.Text != "" { + chunkCount++ + } + + case <-streamCtx.Done(): + t.Fatal("Prompt stream reading timed out") + } + } + + promptStreamComplete: + assert.Greater(t, chunkCount, 0, "Should receive at least one transcription chunk") + t.Logf("βœ… Custom prompt streaming successful: %d chunks received", chunkCount) + }) + }) +} diff --git a/tests/core-providers/scenarios/utils.go b/tests/core-providers/scenarios/utils.go new file mode 100644 index 000000000..27937762f --- /dev/null +++ b/tests/core-providers/scenarios/utils.go @@ -0,0 +1,372 @@ +package scenarios + +import ( + "context" + "os" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// Shared test texts for TTS->SST round-trip validation +const ( + // Basic test text for simple round-trip validation + TTSTestTextBasic = "Hello, this is a test of speech synthesis from Bifrost." + + // Medium length text with punctuation for comprehensive testing + TTSTestTextMedium = "Testing speech synthesis and transcription round-trip. This text includes punctuation, numbers like 123, and technical terms." + + // Short technical text for WAV format testing + TTSTestTextTechnical = "Bifrost AI gateway processes audio requests efficiently." +) + +// GetProviderVoice returns an appropriate voice for the given provider +func GetProviderVoice(provider schemas.ModelProvider, voiceType string) string { + switch provider { + case schemas.OpenAI: + switch voiceType { + case "primary": + return "alloy" + case "secondary": + return "nova" + case "tertiary": + return "echo" + default: + return "alloy" + } + case schemas.Gemini: + switch voiceType { + case "primary": + return "achernar" + case "secondary": + return "aoede" + case "tertiary": + return "charon" + default: + return "achernar" + } + default: + // Default to OpenAI voices for other providers + switch voiceType { + case "primary": + return "alloy" + case "secondary": + return "nova" + case "tertiary": + return "echo" + default: + return "alloy" + } + } +} + +// Tool definitions for testing +var WeatherToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, +} + +var CalculatorToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "calculate", + Description: "Perform basic mathematical calculations", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "expression": map[string]interface{}{ + "type": "string", + "description": "The mathematical expression to evaluate, e.g. '2 + 3' or '10 * 5'", + }, + }, + Required: []string{"expression"}, + }, + }, +} + +var TimeToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_current_time", + Description: "Get the current time in a specific timezone", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier, e.g. 'America/New_York' or 'UTC'", + }, + }, + Required: []string{"timezone"}, + }, + }, +} + +// Test images for testing +const TestImageURL = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" +const TestImageBase64 = "" + +// CreateSpeechInput creates a basic speech input for testing +func CreateSpeechInput(text, voice, format string) *schemas.SpeechInput { + return &schemas.SpeechInput{ + Input: text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: format, + } +} + +// CreateTranscriptionInput creates a basic transcription input for testing +func CreateTranscriptionInput(audioData []byte, language, responseFormat *string) *schemas.TranscriptionInput { + return &schemas.TranscriptionInput{ + File: audioData, + Language: language, + ResponseFormat: responseFormat, + } +} + +// Helper functions for creating requests +func CreateBasicChatMessage(content string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(content), + }, + } +} + +func CreateImageMessage(text, imageURL string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr(text), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: imageURL, + }, + }, + }, + }, + } +} + +func CreateToolMessage(content string, toolCallID string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(content), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: &toolCallID, + }, + } +} + +// GetResultContent returns the string content from a BifrostResponse +// It looks through all choices and returns content from the first choice that has any +func GetResultContent(result *schemas.BifrostResponse) string { + if result == nil || len(result.Choices) == 0 { + return "" + } + + // Try to find content from any choice, prioritizing non-empty content + for _, choice := range result.Choices { + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + return *choice.Message.Content.ContentStr + } else if choice.Message.Content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + content := builder.String() + if content != "" { + return content + } + } + } + + // Fallback to first choice if no content found + if result.Choices[0].Message.Content.ContentStr != nil { + return *result.Choices[0].Message.Content.ContentStr + } else if result.Choices[0].Message.Content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range *result.Choices[0].Message.Content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + return builder.String() + } + return "" +} + +// MergeModelParameters performs a shallow merge of two ModelParameters instances. +// Non-nil fields from the override parameter take precedence over the base parameter. +// Returns a new ModelParameters instance with the merged values. +func MergeModelParameters(base *schemas.ModelParameters, override *schemas.ModelParameters) *schemas.ModelParameters { + if base == nil && override == nil { + return &schemas.ModelParameters{} + } + if base == nil { + return copyModelParameters(override) + } + if override == nil { + return copyModelParameters(base) + } + + // Start with a copy of base parameters + result := copyModelParameters(base) + + // Override with non-nil fields from override + if override.MaxTokens != nil { + result.MaxTokens = override.MaxTokens + } + if override.Temperature != nil { + result.Temperature = override.Temperature + } + if override.TopP != nil { + result.TopP = override.TopP + } + if override.TopK != nil { + result.TopK = override.TopK + } + if override.FrequencyPenalty != nil { + result.FrequencyPenalty = override.FrequencyPenalty + } + if override.PresencePenalty != nil { + result.PresencePenalty = override.PresencePenalty + } + if override.StopSequences != nil { + result.StopSequences = override.StopSequences + } + if override.Tools != nil { + result.Tools = override.Tools + } + if override.ToolChoice != nil { + result.ToolChoice = override.ToolChoice + } + if override.ParallelToolCalls != nil { + result.ParallelToolCalls = override.ParallelToolCalls + } + if override.EncodingFormat != nil { + result.EncodingFormat = override.EncodingFormat + } + if override.Dimensions != nil { + result.Dimensions = override.Dimensions + } + if override.User != nil { + result.User = override.User + } + if override.ExtraParams != nil { + result.ExtraParams = override.ExtraParams + } + + return result +} + +// copyModelParameters creates a shallow copy of a ModelParameters instance +func copyModelParameters(src *schemas.ModelParameters) *schemas.ModelParameters { + if src == nil { + return &schemas.ModelParameters{} + } + + return &schemas.ModelParameters{ + MaxTokens: src.MaxTokens, + Temperature: src.Temperature, + TopP: src.TopP, + TopK: src.TopK, + FrequencyPenalty: src.FrequencyPenalty, + PresencePenalty: src.PresencePenalty, + StopSequences: src.StopSequences, + Tools: src.Tools, + ToolChoice: src.ToolChoice, + ParallelToolCalls: src.ParallelToolCalls, + EncodingFormat: src.EncodingFormat, + Dimensions: src.Dimensions, + User: src.User, + ExtraParams: src.ExtraParams, + } +} + +// --- Additional test helpers appended below (imported on demand) --- + +// NOTE: importing context, os, testing only in this block to avoid breaking existing imports. +// We duplicate types by fully qualifying to not touch import list above. + +// GenerateTTSAudioForTest generates real audio using TTS and writes a temp file. +// Returns audio bytes and temp filepath. Caller’s t will clean it up. +func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost.Bifrost, provider schemas.ModelProvider, ttsModel string, text string, voiceType string, format string) ([]byte, string) { + // inline import guard comment: context/testing/os are required at call sites; Go compiler will include them. + voice := GetProviderVoice(provider, voiceType) + if voice == "" { + voice = GetProviderVoice(provider, "primary") + } + if format == "" { + format = "mp3" + } + + req := &schemas.BifrostRequest{ + Provider: provider, + Model: ttsModel, + Input: schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: text, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, + }, + ResponseFormat: format, + }, + }, + } + + resp, err := client.SpeechRequest(ctx, req) + if err != nil { + t.Fatalf("TTS request failed: %v", err) + } + if resp == nil || resp.Speech == nil || len(resp.Speech.Audio) == 0 { + t.Fatalf("TTS response missing audio data") + } + + suffix := "." + format + f, cerr := os.CreateTemp("", "bifrost-tts-*"+suffix) + if cerr != nil { + t.Fatalf("failed to create temp audio file: %v", cerr) + } + tempPath := f.Name() + if _, werr := f.Write(resp.Speech.Audio); werr != nil { + _ = f.Close() + t.Fatalf("failed to write temp audio file: %v", werr) + } + _ = f.Close() + + t.Cleanup(func() { _ = os.Remove(tempPath) }) + + return resp.Speech.Audio, tempPath +} diff --git a/tests/core-providers/sgl_test.go b/tests/core-providers/sgl_test.go new file mode 100644 index 000000000..967e36481 --- /dev/null +++ b/tests/core-providers/sgl_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestSGL(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.SGL, + ChatModel: "Qwen2.5-VL-7B-Instruct", + TextModel: "", // SGL doesn't support text completion + EmbeddingModel: "", // SGL doesn't support embedding + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: false, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/tests.go b/tests/core-providers/tests.go new file mode 100644 index 000000000..9f0631bdf --- /dev/null +++ b/tests/core-providers/tests.go @@ -0,0 +1,113 @@ +package tests + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + "github.com/maximhq/bifrost/tests/core-providers/scenarios" + + bifrost "github.com/maximhq/bifrost/core" +) + +// TestScenarioFunc defines the function signature for test scenario functions +type TestScenarioFunc func(*testing.T, *bifrost.Bifrost, context.Context, config.ComprehensiveTestConfig) + +// runAllComprehensiveTests executes all comprehensive test scenarios for a given configuration +func runAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if testConfig.SkipReason != "" { + t.Skipf("Skipping %s: %s", testConfig.Provider, testConfig.SkipReason) + return + } + + t.Logf("πŸš€ Running comprehensive tests for provider: %s", testConfig.Provider) + + // Define all test scenario functions in a slice + testScenarios := []TestScenarioFunc{ + scenarios.RunTextCompletionTest, + scenarios.RunSimpleChatTest, + scenarios.RunChatCompletionStreamTest, + scenarios.RunMultiTurnConversationTest, + scenarios.RunToolCallsTest, + scenarios.RunMultipleToolCallsTest, + scenarios.RunEnd2EndToolCallingTest, + scenarios.RunAutomaticFunctionCallingTest, + scenarios.RunImageURLTest, + scenarios.RunImageBase64Test, + scenarios.RunMultipleImagesTest, + scenarios.RunCompleteEnd2EndTest, + scenarios.RunProviderSpecificTest, + scenarios.RunSpeechSynthesisTest, + scenarios.RunSpeechSynthesisAdvancedTest, + scenarios.RunSpeechSynthesisStreamTest, + scenarios.RunSpeechSynthesisStreamAdvancedTest, + scenarios.RunTranscriptionTest, + scenarios.RunTranscriptionAdvancedTest, + scenarios.RunTranscriptionStreamTest, + scenarios.RunTranscriptionStreamAdvancedTest, + scenarios.RunEmbeddingTest, + } + + // Execute all test scenarios + for _, scenarioFunc := range testScenarios { + scenarioFunc(t, client, ctx, testConfig) + } + + // Print comprehensive summary based on configuration + printTestSummary(t, testConfig) +} + +// printTestSummary prints a detailed summary of all test scenarios +func printTestSummary(t *testing.T, testConfig config.ComprehensiveTestConfig) { + testScenarios := []struct { + name string + supported bool + }{ + {"TextCompletion", testConfig.Scenarios.TextCompletion && testConfig.TextModel != ""}, + {"SimpleChat", testConfig.Scenarios.SimpleChat}, + {"ChatCompletionStream", testConfig.Scenarios.ChatCompletionStream}, + {"MultiTurnConversation", testConfig.Scenarios.MultiTurnConversation}, + {"ToolCalls", testConfig.Scenarios.ToolCalls}, + {"MultipleToolCalls", testConfig.Scenarios.MultipleToolCalls}, + {"End2EndToolCalling", testConfig.Scenarios.End2EndToolCalling}, + {"AutomaticFunctionCall", testConfig.Scenarios.AutomaticFunctionCall}, + {"ImageURL", testConfig.Scenarios.ImageURL}, + {"ImageBase64", testConfig.Scenarios.ImageBase64}, + {"MultipleImages", testConfig.Scenarios.MultipleImages}, + {"CompleteEnd2End", testConfig.Scenarios.CompleteEnd2End}, + {"ProviderSpecific", testConfig.Scenarios.ProviderSpecific}, + {"SpeechSynthesis", testConfig.Scenarios.SpeechSynthesis}, + {"SpeechSynthesisStream", testConfig.Scenarios.SpeechSynthesisStream}, + {"Transcription", testConfig.Scenarios.Transcription}, + {"TranscriptionStream", testConfig.Scenarios.TranscriptionStream}, + {"Embedding", testConfig.Scenarios.Embedding && testConfig.EmbeddingModel != ""}, + } + + supported := 0 + unsupported := 0 + + t.Logf("\n%s", strings.Repeat("=", 80)) + t.Logf("COMPREHENSIVE TEST SUMMARY FOR PROVIDER: %s", strings.ToUpper(string(testConfig.Provider))) + t.Logf("%s", strings.Repeat("=", 80)) + + for _, scenario := range testScenarios { + if scenario.supported { + supported++ + t.Logf("βœ… SUPPORTED: %-25s βœ… Configured to run", scenario.name) + } else { + unsupported++ + t.Logf("❌ UNSUPPORTED: %-25s ❌ Not supported by provider", scenario.name) + } + } + + t.Logf("%s", strings.Repeat("-", 80)) + t.Logf("CONFIGURATION SUMMARY:") + t.Logf(" βœ… Supported Tests: %d", supported) + t.Logf(" ❌ Unsupported Tests: %d", unsupported) + t.Logf(" πŸ“Š Total Test Types: %d", len(testScenarios)) + t.Logf("") + t.Logf("ℹ️ NOTE: Actual PASS/FAIL results are shown in the individual test output above.") + t.Logf("ℹ️ Look for individual test results like 'PASS: TestOpenAI/SimpleChat' or 'FAIL: TestOpenAI/ToolCalls'") + t.Logf("%s\n", strings.Repeat("=", 80)) +} diff --git a/tests/core-providers/vertex_test.go b/tests/core-providers/vertex_test.go new file mode 100644 index 000000000..a9a3de82a --- /dev/null +++ b/tests/core-providers/vertex_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestVertex(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Shutdown() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Vertex, + ChatModel: "google/gemini-2.0-flash-001", + TextModel: "", // Vertex doesn't support text completion in newer models + EmbeddingModel: "text-multilingual-embedding-002", + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + Embedding: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 000000000..b45bed796 --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,59 @@ +services: + # Weaviate instance for basic tests + weaviate: + image: cr.weaviate.io/semitechnologies/weaviate:1.32.4 + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + environment: + - CLUSTER_HOSTNAME=weaviate + - CLUSTER_ADVERTISE_ADDR=172.38.0.12 + - CLUSTER_GOSSIP_BIND_PORT=7946 + - CLUSTER_DATA_BIND_PORT=7947 + - DISABLE_TELEMETRY=true + - PERSISTENCE_DATA_PATH=/var/lib/weaviate + - DEFAULT_VECTORIZER_MODULE=none + - ENABLE_MODULES= + - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true + - LOG_LEVEL=info + ports: + - "9000:8080" + volumes: + - weaviate_data:/var/lib/weaviate + networks: + bifrost_network: + ipv4_address: 172.38.0.12 + + # Redis Stack instance for vector store tests + redis-stack: + image: redis/redis-stack:7.4.0-v6 + command: redis-stack-server --protected-mode no + ports: + - "6379:6379" + - "8001:8001" # RedisInsight web UI + volumes: + - redis_data:/data + networks: + bifrost_network: + ipv4_address: 172.38.0.13 + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 3 + +networks: + bifrost_network: + driver: bridge + ipam: + config: + - subnet: 172.38.0.0/16 + gateway: 172.38.0.1 + +volumes: + weaviate_data: + redis_data: \ No newline at end of file diff --git a/tests/governance/README.md b/tests/governance/README.md new file mode 100644 index 000000000..1cbc0f988 --- /dev/null +++ b/tests/governance/README.md @@ -0,0 +1,388 @@ +# Bifrost Governance Plugin Test Suite + +A comprehensive test suite for the Bifrost Governance Plugin, testing hierarchical governance, budgets, rate limiting, usage tracking, and CRUD operations. + +## Overview + +This test suite provides extensive coverage of the Bifrost governance system including: + +- **Virtual Key Management**: Complete CRUD operations with comprehensive field update testing +- **Team Management**: Team CRUD with customer relationships and budget inheritance +- **Customer Management**: Customer CRUD with team hierarchies and budget controls +- **Usage Tracking**: Real-time usage monitoring and audit logging +- **Rate Limiting**: Flexible token and request rate limiting with configurable reset periods +- **Budget Enforcement**: Hierarchical budget controls (Customer β†’ Team β†’ Virtual Key) +- **Integration Testing**: End-to-end testing with chat completion API +- **Edge Cases**: Boundary conditions, concurrency, and error scenarios + +## Test Structure + +### Test Files + +1. **`test_virtual_keys_crud.py`** - Virtual Key CRUD operations + - Complete CRUD lifecycle testing + - Comprehensive field update testing (individual and batch) + - Mutual exclusivity validation (team_id vs customer_id) + - Budget and rate limit management + - Relationship testing with teams and customers + +2. **`test_teams_crud.py`** - Team CRUD operations + - Team lifecycle management + - Customer association testing + - Budget inheritance and conflicts + - Comprehensive field updates + - Filtering and relationships + +3. **`test_customers_crud.py`** - Customer CRUD operations + - Customer lifecycle management + - Team relationship management + - Budget management and hierarchies + - Comprehensive field updates + - Cascading operations + +4. **`test_usage_tracking.py`** - Usage tracking and monitoring + - Chat completion integration with governance headers + - Usage tracking and budget enforcement + - Rate limiting enforcement + - Monitoring endpoints + - Reset functionality + - Debug and health endpoints + +### Configuration Files + +- **`conftest.py`** - Test fixtures, utilities, and configuration +- **`pytest.ini`** - pytest configuration with markers and settings +- **`requirements.txt`** - Test dependencies +- **`__init__.py`** - Package initialization + +## Key Features + +### Comprehensive Field Update Testing + +Each entity (Virtual Key, Team, Customer) has exhaustive field update tests that verify: + +- **Individual field updates** - Each field updated independently +- **Unchanged field verification** - Other fields remain unmodified +- **Relationship preservation** - Associated data maintained correctly +- **Timestamp validation** - updated_at changes, created_at preserved +- **Multiple field updates** - Batch field modifications +- **Nested object updates** - Budget and rate limit sub-objects +- **Edge cases** - Empty updates, null values, invalid data + +### Mutual Exclusivity Testing + +Critical validation of Virtual Key constraints: +- VK can have `team_id` OR `customer_id`, but NEVER both +- Switching between team and customer associations +- Validation error scenarios for invalid combinations + +### Hierarchical Testing + +Testing the Customer β†’ Team β†’ Virtual Key hierarchy: +- Budget inheritance and override scenarios +- Rate limit cascading and conflicts +- Usage tracking across hierarchy levels +- Permission and access control validation + +### Integration Testing + +End-to-end testing with actual chat completion requests: +- Governance header validation (`x-bf-vk`) +- Usage tracking during real requests +- Budget enforcement during streaming +- Rate limiting during concurrent requests +- Provider and model access control + +## Setup and Usage + +### Prerequisites + +1. **Bifrost Server Running**: The governance plugin must be running on `localhost:8080` +2. **Python 3.8+**: Required for the test suite +3. **Dependencies**: Install via `pip install -r requirements.txt` + +### Environment Configuration + +Set the following environment variables (optional): + +```bash +export BIFROST_BASE_URL="http://localhost:8080" # Default +export GOVERNANCE_TEST_TIMEOUT="300" # Test timeout in seconds +export GOVERNANCE_TEST_CLEANUP="true" # Auto-cleanup entities +``` + +### Running Tests + +```bash +# Install dependencies +pip install -r requirements.txt + +# Run all governance tests +pytest + +# Run specific test files +pytest test_virtual_keys_crud.py +pytest test_teams_crud.py +pytest test_customers_crud.py +pytest test_usage_tracking.py + +# Run with specific markers +pytest -m "virtual_keys" +pytest -m "field_updates" +pytest -m "edge_cases" +pytest -m "integration" + +# Run with coverage +pytest --cov=. --cov-report=html + +# Run in parallel +pytest -n auto + +# Run with verbose output +pytest -v + +# Run smoke tests only +pytest -m "smoke" +``` + +### Test Markers + +The test suite uses pytest markers for categorization: + +- `@pytest.mark.virtual_keys` - Virtual Key related tests +- `@pytest.mark.teams` - Team related tests +- `@pytest.mark.customers` - Customer related tests +- `@pytest.mark.field_updates` - Comprehensive field update tests +- `@pytest.mark.mutual_exclusivity` - Mutual exclusivity constraint tests +- `@pytest.mark.budget` - Budget related tests +- `@pytest.mark.rate_limit` - Rate limiting tests +- `@pytest.mark.usage_tracking` - Usage tracking tests +- `@pytest.mark.integration` - Integration tests +- `@pytest.mark.edge_cases` - Edge case tests +- `@pytest.mark.concurrency` - Concurrency tests +- `@pytest.mark.slow` - Slow running tests (>5s) +- `@pytest.mark.smoke` - Quick smoke tests + +## API Endpoints Tested + +### Virtual Key Endpoints +- `GET /api/governance/virtual-keys` - List all VKs with relationships +- `POST /api/governance/virtual-keys` - Create VK with optional budget/rate limits +- `GET /api/governance/virtual-keys/{vk_id}` - Get specific VK +- `PUT /api/governance/virtual-keys/{vk_id}` - Update VK +- `DELETE /api/governance/virtual-keys/{vk_id}` - Delete VK + +### Team Endpoints +- `GET /api/governance/teams` - List teams with optional customer filter +- `POST /api/governance/teams` - Create team with optional customer/budget +- `GET /api/governance/teams/{team_id}` - Get specific team +- `PUT /api/governance/teams/{team_id}` - Update team +- `DELETE /api/governance/teams/{team_id}` - Delete team + +### Customer Endpoints +- `GET /api/governance/customers` - List customers with teams/budgets +- `POST /api/governance/customers` - Create customer with optional budget +- `GET /api/governance/customers/{customer_id}` - Get specific customer +- `PUT /api/governance/customers/{customer_id}` - Update customer +- `DELETE /api/governance/customers/{customer_id}` - Delete customer + +### Monitoring Endpoints +- `GET /api/governance/usage-stats` - Usage statistics with optional VK filter +- `POST /api/governance/usage-reset` - Reset VK usage counters +- `GET /api/governance/debug/stats` - Debug statistics +- `GET /api/governance/debug/counters` - All VK usage counters +- `GET /api/governance/debug/health` - Health check + +### Integration Endpoints +- `POST /v1/chat/completions` - Chat completion with governance headers + +## Test Data and Schemas + +### Virtual Key Request Schema +```json +{ + "name": "string (required)", + "description": "string (optional)", + "allowed_models": ["string"] (optional), + "allowed_providers": ["string"] (optional), + "team_id": "string (optional, mutually exclusive with customer_id)", + "customer_id": "string (optional, mutually exclusive with team_id)", + "budget": { + "max_limit": "integer (cents)", + "reset_duration": "string (e.g., '1h', '1d')" + }, + "rate_limit": { + "token_max_limit": "integer (optional)", + "token_reset_duration": "string (optional)", + "request_max_limit": "integer (optional)", + "request_reset_duration": "string (optional)" + }, + "is_active": "boolean (optional, default true)" +} +``` + +### Team Request Schema +```json +{ + "name": "string (required)", + "customer_id": "string (optional)", + "budget": { + "max_limit": "integer (cents)", + "reset_duration": "string" + } +} +``` + +### Customer Request Schema +```json +{ + "name": "string (required)", + "budget": { + "max_limit": "integer (cents)", + "reset_duration": "string" + } +} +``` + +## Edge Cases Covered + +### Budget Edge Cases +- Boundary values: 0, negative, max int64, overflow +- Reset timing: exact boundaries, concurrent resets +- Hierarchical conflicts: VK vs Team vs Customer budgets +- Fractional costs: proper cents handling +- Concurrent usage: multiple requests hitting limits +- Reset during flight: budget resets while processing +- Streaming cost tracking: partial vs final costs + +### Rate Limiting Edge Cases +- Independent limits: token vs request limits with different resets +- Sub-second precision: very short reset durations +- Burst scenarios: simultaneous requests +- Provider variations: different limits per provider/model +- Streaming rate limits: token counting across chunks +- Reset race conditions: limits resetting during validation + +### Relationship Edge Cases +- Orphaned entities: VKs without parent relationships +- Invalid references: team_id pointing to non-existent team +- Mutual exclusivity: VK with both team_id and customer_id (MUST FAIL) +- Circular dependencies: prevention testing +- Deep hierarchies: Customer β†’ Team β†’ VK inheritance + +### Update Edge Cases +- Partial updates: only some fields updated +- Null handling: null values clearing optional fields +- Type validation: wrong data types in requests +- Concurrent updates: multiple clients updating same entity +- Cache invalidation: in-memory cache updates after DB changes +- Rollback scenarios: failed updates don't leave partial changes + +### Integration Edge Cases +- Missing headers: requests without x-bf-vk header +- Invalid headers: malformed or non-existent VK values +- Provider/model validation: invalid combinations +- Error propagation: governance vs completion errors +- Streaming interruption: governance blocking mid-stream +- Context preservation: headers passed through request lifecycle + +## Utilities and Helpers + +### Test Fixtures +- `governance_client` - API client for governance endpoints +- `cleanup_tracker` - Automatic entity cleanup after tests +- `sample_customer` - Pre-created customer for testing +- `sample_team` - Pre-created team for testing +- `sample_virtual_key` - Pre-created virtual key for testing +- `field_update_tester` - Helper for comprehensive field update testing + +### Utility Functions +- `generate_unique_name()` - Generate unique test entity names +- `wait_for_condition()` - Wait for async conditions +- `assert_response_success()` - Assert HTTP response success +- `deep_compare_entities()` - Deep comparison of entity data +- `verify_unchanged_fields()` - Verify fields remain unchanged +- `create_complete_virtual_key_data()` - Generate complete VK data + +### Error Handling +- Comprehensive error assertion helpers +- Automatic retry for transient failures +- Detailed error logging and reporting +- Clean failure modes with proper cleanup + +## Performance and Concurrency + +### Performance Testing +- Response time benchmarks for all endpoints +- Memory usage monitoring during tests +- Database query optimization validation +- Cache performance verification + +### Concurrency Testing +- Race condition detection +- Concurrent entity creation/updates +- Simultaneous budget usage scenarios +- Rate limit burst testing +- Cache consistency under load + +## Debugging and Monitoring + +### Test Logging +- Comprehensive test execution logging +- API request/response logging +- Error details and stack traces +- Performance metrics and timing + +### Debug Endpoints +- Test coverage of debug/stats endpoint +- Usage counter validation +- Health check verification +- Database state inspection + +## Contributing + +When adding new tests: + +1. **Follow naming conventions**: `test__.py` +2. **Use appropriate markers**: Mark tests with relevant pytest markers +3. **Include cleanup**: Use `cleanup_tracker` fixture for entity cleanup +4. **Document edge cases**: Comment complex test scenarios +5. **Add field update tests**: For any new entity fields, add comprehensive update tests +6. **Test relationships**: Verify entity relationships and cascading effects +7. **Include negative tests**: Test validation and error scenarios + +### Test Development Guidelines + +1. **Comprehensive Coverage**: Test all CRUD operations, field updates, and edge cases +2. **Isolation**: Tests should be independent and not rely on other test state +3. **Cleanup**: Always clean up created entities to avoid test interference +4. **Documentation**: Comment complex test logic and expected behaviors +5. **Performance**: Mark slow tests appropriately and optimize where possible +6. **Error Scenarios**: Test both success and failure paths +7. **Relationships**: Verify entity relationships are properly maintained + +## Troubleshooting + +### Common Issues + +1. **Server Not Running**: Ensure Bifrost server is running on localhost:8080 +2. **Permission Errors**: Check that test has access to create/delete entities +3. **Cleanup Failures**: Manually clean up test entities if auto-cleanup fails +4. **Timeout Errors**: Increase timeout for slow-running tests +5. **Concurrency Issues**: Use appropriate locks for shared resource tests + +### Debug Commands + +```bash +# Run with maximum verbosity +pytest -vvv --tb=long + +# Run single test with debugging +pytest -s test_virtual_keys_crud.py::test_vk_create_basic + +# Run with profiling +pytest --profile-svg + +# Check test coverage +pytest --cov=. --cov-report=term-missing +``` \ No newline at end of file diff --git a/tests/governance/__init__.py b/tests/governance/__init__.py new file mode 100644 index 000000000..2936e67c9 --- /dev/null +++ b/tests/governance/__init__.py @@ -0,0 +1,31 @@ +""" +Bifrost Governance Plugin Test Suite + +Comprehensive test suite for the Bifrost governance system covering: +- Virtual Key CRUD operations with comprehensive field updates +- Team CRUD operations with hierarchical relationships +- Customer CRUD operations with budget management +- Usage tracking and monitoring +- Rate limiting and budget enforcement +- Integration testing with chat completions +- Edge cases and validation testing +- Concurrency and race condition testing + +Test Structure: +- test_virtual_keys_crud.py: Virtual Key CRUD and field update tests +- test_teams_crud.py: Team CRUD and field update tests +- test_customers_crud.py: Customer CRUD and field update tests +- test_usage_tracking.py: Usage tracking, monitoring, and integration tests +- conftest.py: Test fixtures and utilities + +Key Features: +- Comprehensive field update testing for all entities +- Mutual exclusivity validation (VK team_id vs customer_id) +- Hierarchical budget and rate limit testing +- Automatic test entity cleanup +- Concurrent testing support +- Edge case and boundary condition coverage +""" + +__version__ = "1.0.0" +__author__ = "Bifrost Team" diff --git a/tests/governance/conftest.py b/tests/governance/conftest.py new file mode 100644 index 000000000..84d77c2d0 --- /dev/null +++ b/tests/governance/conftest.py @@ -0,0 +1,668 @@ +""" +Pytest configuration for Bifrost Governance Plugin testing. + +Provides comprehensive setup, fixtures, and utilities for testing the +Bifrost governance system with hierarchical budgets, rate limiting, +usage tracking, and CRUD operations for Virtual Keys, Teams, and Customers. +""" + +import pytest +import requests +import json +import uuid +import time +import os +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any, Tuple +from concurrent.futures import ThreadPoolExecutor +import threading +from dataclasses import dataclass +import copy + + +# Test Configuration +BIFROST_BASE_URL = os.getenv("BIFROST_BASE_URL", "http://localhost:8080") +GOVERNANCE_API_BASE = f"{BIFROST_BASE_URL}/api/governance" +COMPLETION_API_BASE = f"{BIFROST_BASE_URL}/v1" + + +def pytest_configure(config): + """Configure pytest with custom markers for governance testing""" + markers = [ + "governance: mark test as governance-related", + "virtual_keys: mark test as virtual key test", + "teams: mark test as team test", + "customers: mark test as customer test", + "budget: mark test as budget-related", + "rate_limit: mark test as rate limit-related", + "usage_tracking: mark test as usage tracking test", + "crud: mark test as CRUD operation test", + "field_updates: mark test as comprehensive field update test", + "validation: mark test as validation test", + "integration: mark test as integration test", + "edge_cases: mark test as edge case test", + "concurrency: mark test as concurrency test", + "mutual_exclusivity: mark test as mutual exclusivity test", + "hierarchical: mark test as hierarchical governance test", + "slow: mark test as slow running (>5s)", + "smoke: mark test as smoke test", + ] + + for marker in markers: + config.addinivalue_line("markers", marker) + + +@dataclass +class TestEntity: + """Base class for test entities""" + + id: str + created_at: Optional[str] = None + updated_at: Optional[str] = None + + +@dataclass +class TestBudget(TestEntity): + """Test budget entity""" + + max_limit: int = 0 + reset_duration: str = "" + current_usage: int = 0 + last_reset: Optional[str] = None + + +@dataclass +class TestRateLimit(TestEntity): + """Test rate limit entity""" + + token_max_limit: Optional[int] = None + token_reset_duration: Optional[str] = None + request_max_limit: Optional[int] = None + request_reset_duration: Optional[str] = None + token_current_usage: int = 0 + request_current_usage: int = 0 + token_last_reset: Optional[str] = None + request_last_reset: Optional[str] = None + + +@dataclass +class TestCustomer(TestEntity): + """Test customer entity""" + + name: str = "" + budget_id: Optional[str] = None + budget: Optional[TestBudget] = None + teams: Optional[List["TestTeam"]] = None + + +@dataclass +class TestTeam(TestEntity): + """Test team entity""" + + name: str = "" + customer_id: Optional[str] = None + budget_id: Optional[str] = None + customer: Optional[TestCustomer] = None + budget: Optional[TestBudget] = None + + +@dataclass +class TestVirtualKey(TestEntity): + """Test virtual key entity""" + + name: str = "" + value: str = "" + description: str = "" + allowed_models: Optional[List[str]] = None + allowed_providers: Optional[List[str]] = None + team_id: Optional[str] = None + customer_id: Optional[str] = None + budget_id: Optional[str] = None + rate_limit_id: Optional[str] = None + is_active: bool = True + team: Optional[TestTeam] = None + customer: Optional[TestCustomer] = None + budget: Optional[TestBudget] = None + rate_limit: Optional[TestRateLimit] = None + + +class GovernanceTestClient: + """HTTP client for governance API testing with comprehensive error handling""" + + def __init__(self, base_url: str = GOVERNANCE_API_BASE): + self.base_url = base_url + self.session = requests.Session() + self.session.headers.update({"Content-Type": "application/json"}) + + def request(self, method: str, endpoint: str, **kwargs) -> requests.Response: + """Make HTTP request with comprehensive error handling""" + url = f"{self.base_url}/{endpoint.lstrip('/')}" + try: + response = self.session.request(method, url, **kwargs) + return response + except requests.exceptions.RequestException as e: + pytest.fail(f"Request failed: {method} {url} - {str(e)}") + + # Virtual Key operations + def list_virtual_keys(self, **params) -> requests.Response: + """List all virtual keys""" + return self.request("GET", "/virtual-keys", params=params) + + def create_virtual_key(self, data: Dict[str, Any]) -> requests.Response: + """Create a virtual key""" + return self.request("POST", "/virtual-keys", json=data) + + def get_virtual_key(self, vk_id: str) -> requests.Response: + """Get virtual key by ID""" + return self.request("GET", f"/virtual-keys/{vk_id}") + + def update_virtual_key(self, vk_id: str, data: Dict[str, Any]) -> requests.Response: + """Update virtual key""" + return self.request("PUT", f"/virtual-keys/{vk_id}", json=data) + + def delete_virtual_key(self, vk_id: str) -> requests.Response: + """Delete virtual key""" + return self.request("DELETE", f"/virtual-keys/{vk_id}") + + # Team operations + def list_teams(self, **params) -> requests.Response: + """List all teams""" + return self.request("GET", "/teams", params=params) + + def create_team(self, data: Dict[str, Any]) -> requests.Response: + """Create a team""" + return self.request("POST", "/teams", json=data) + + def get_team(self, team_id: str) -> requests.Response: + """Get team by ID""" + return self.request("GET", f"/teams/{team_id}") + + def update_team(self, team_id: str, data: Dict[str, Any]) -> requests.Response: + """Update team""" + return self.request("PUT", f"/teams/{team_id}", json=data) + + def delete_team(self, team_id: str) -> requests.Response: + """Delete team""" + return self.request("DELETE", f"/teams/{team_id}") + + # Customer operations + def list_customers(self, **params) -> requests.Response: + """List all customers""" + return self.request("GET", "/customers", params=params) + + def create_customer(self, data: Dict[str, Any]) -> requests.Response: + """Create a customer""" + return self.request("POST", "/customers", json=data) + + def get_customer(self, customer_id: str) -> requests.Response: + """Get customer by ID""" + return self.request("GET", f"/customers/{customer_id}") + + def update_customer( + self, customer_id: str, data: Dict[str, Any] + ) -> requests.Response: + """Update customer""" + return self.request("PUT", f"/customers/{customer_id}", json=data) + + def delete_customer(self, customer_id: str) -> requests.Response: + """Delete customer""" + return self.request("DELETE", f"/customers/{customer_id}") + + # Monitoring and usage operations + def get_usage_stats(self, **params) -> requests.Response: + """Get usage statistics""" + return self.request("GET", "/usage-stats", params=params) + + def reset_usage(self, data: Dict[str, Any]) -> requests.Response: + """Reset usage counters""" + return self.request("POST", "/usage-reset", json=data) + + def get_debug_stats(self) -> requests.Response: + """Get debug statistics""" + return self.request("GET", "/debug/stats") + + def get_debug_counters(self) -> requests.Response: + """Get debug counters""" + return self.request("GET", "/debug/counters") + + def get_health_check(self) -> requests.Response: + """Get health check""" + return self.request("GET", "/debug/health") + + # Chat completion for integration testing + def chat_completion( + self, + messages: List[Dict], + model: str = "gpt-3.5-turbo", + headers: Optional[Dict] = None, + **kwargs, + ) -> requests.Response: + """Make chat completion request""" + data = {"model": model, "messages": messages, **kwargs} + + session_headers = self.session.headers.copy() + if headers: + session_headers.update(headers) + + url = f"{COMPLETION_API_BASE}/chat/completions" + try: + response = requests.post(url, json=data, headers=session_headers) + return response + except requests.exceptions.RequestException as e: + pytest.fail(f"Chat completion request failed: {url} - {str(e)}") + + +class CleanupTracker: + """Tracks entities created during tests for cleanup""" + + def __init__(self): + self.virtual_keys = [] + self.teams = [] + self.customers = [] + self._lock = threading.Lock() + + def add_virtual_key(self, vk_id: str): + """Add virtual key for cleanup""" + with self._lock: + if vk_id not in self.virtual_keys: + self.virtual_keys.append(vk_id) + + def add_team(self, team_id: str): + """Add team for cleanup""" + with self._lock: + if team_id not in self.teams: + self.teams.append(team_id) + + def add_customer(self, customer_id: str): + """Add customer for cleanup""" + with self._lock: + if customer_id not in self.customers: + self.customers.append(customer_id) + + def cleanup(self, client: GovernanceTestClient): + """Cleanup all tracked entities""" + with self._lock: + # Delete in dependency order: VKs -> Teams -> Customers + for vk_id in self.virtual_keys: + try: + client.delete_virtual_key(vk_id) + except Exception: + pass # Ignore cleanup errors + + for team_id in self.teams: + try: + client.delete_team(team_id) + except Exception: + pass + + for customer_id in self.customers: + try: + client.delete_customer(customer_id) + except Exception: + pass + + # Clear lists + self.virtual_keys.clear() + self.teams.clear() + self.customers.clear() + + +# Fixtures + + +@pytest.fixture(scope="session") +def governance_client(): + """Governance API client for the session""" + return GovernanceTestClient() + + +@pytest.fixture +def cleanup_tracker(): + """Cleanup tracker for test entities""" + return CleanupTracker() + + +@pytest.fixture(autouse=True) +def auto_cleanup(cleanup_tracker, governance_client): + """Automatically cleanup test entities after each test""" + yield + cleanup_tracker.cleanup(governance_client) + + +@pytest.fixture +def sample_budget_data(): + """Sample budget data for testing""" + return {"max_limit": 10000, "reset_duration": "1h"} # $100.00 in cents + + +@pytest.fixture +def sample_rate_limit_data(): + """Sample rate limit data for testing""" + return { + "token_max_limit": 1000, + "token_reset_duration": "1m", + "request_max_limit": 100, + "request_reset_duration": "1h", + } + + +@pytest.fixture +def sample_customer(governance_client, cleanup_tracker): + """Create a sample customer for testing""" + data = {"name": f"Test Customer {uuid.uuid4().hex[:8]}"} + response = governance_client.create_customer(data) + assert response.status_code == 201 + customer_data = response.json()["customer"] + cleanup_tracker.add_customer(customer_data["id"]) + return customer_data + + +@pytest.fixture +def sample_team(governance_client, cleanup_tracker): + """Create a sample team for testing""" + data = {"name": f"Test Team {uuid.uuid4().hex[:8]}"} + response = governance_client.create_team(data) + assert response.status_code == 201 + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + return team_data + + +@pytest.fixture +def sample_team_with_customer(governance_client, cleanup_tracker, sample_customer): + """Create a sample team associated with a customer""" + data = { + "name": f"Test Team with Customer {uuid.uuid4().hex[:8]}", + "customer_id": sample_customer["id"], + } + response = governance_client.create_team(data) + assert response.status_code == 201 + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + return team_data + + +@pytest.fixture +def sample_virtual_key(governance_client, cleanup_tracker): + """Create a sample virtual key for testing""" + data = {"name": f"Test VK {uuid.uuid4().hex[:8]}"} + response = governance_client.create_virtual_key(data) + assert response.status_code == 201 + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + return vk_data + + +@pytest.fixture +def sample_virtual_key_with_team(governance_client, cleanup_tracker, sample_team): + """Create a sample virtual key associated with a team""" + data = { + "name": f"Test VK with Team {uuid.uuid4().hex[:8]}", + "team_id": sample_team["id"], + } + response = governance_client.create_virtual_key(data) + assert response.status_code == 201 + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + return vk_data + + +@pytest.fixture +def sample_virtual_key_with_customer( + governance_client, cleanup_tracker, sample_customer +): + """Create a sample virtual key associated with a customer""" + data = { + "name": f"Test VK with Customer {uuid.uuid4().hex[:8]}", + "customer_id": sample_customer["id"], + } + response = governance_client.create_virtual_key(data) + assert response.status_code == 201 + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + return vk_data + + +# Utility functions + + +def generate_unique_name(prefix: str = "Test") -> str: + """Generate a unique name for testing""" + return f"{prefix} {uuid.uuid4().hex[:8]} {int(time.time())}" + + +def wait_for_condition( + condition_func, timeout: float = 5.0, interval: float = 0.1 +) -> bool: + """Wait for a condition to be true""" + start_time = time.time() + while time.time() - start_time < timeout: + if condition_func(): + return True + time.sleep(interval) + return False + + +def assert_response_success(response: requests.Response, expected_status: int = 200): + """Assert that response is successful with expected status""" + if response.status_code != expected_status: + try: + error_data = response.json() + pytest.fail( + f"Expected status {expected_status}, got {response.status_code}: {error_data}" + ) + except: + pytest.fail( + f"Expected status {expected_status}, got {response.status_code}: {response.text}" + ) + + +def assert_field_unchanged(actual_value, expected_value, field_name: str): + """Assert that a field value hasn't changed""" + if actual_value != expected_value: + pytest.fail( + f"Field '{field_name}' changed unexpectedly. Expected: {expected_value}, Got: {actual_value}" + ) + + +def deep_compare_entities( + entity1: Dict, entity2: Dict, ignore_fields: List[str] = None +) -> List[str]: + """Deep compare two entities and return list of differences""" + if ignore_fields is None: + ignore_fields = ["updated_at", "created_at"] + + differences = [] + + def compare_values(path: str, val1, val2): + if isinstance(val1, dict) and isinstance(val2, dict): + for key in set(val1.keys()) | set(val2.keys()): + if key in ignore_fields: + continue + new_path = f"{path}.{key}" if path else key + if key not in val1: + differences.append(f"{new_path}: missing in first entity") + elif key not in val2: + differences.append(f"{new_path}: missing in second entity") + else: + compare_values(new_path, val1[key], val2[key]) + elif isinstance(val1, list) and isinstance(val2, list): + if len(val1) != len(val2): + differences.append( + f"{path}: list length differs ({len(val1)} vs {len(val2)})" + ) + else: + for i, (item1, item2) in enumerate(zip(val1, val2)): + compare_values(f"{path}[{i}]", item1, item2) + elif val1 != val2: + differences.append(f"{path}: {val1} != {val2}") + + compare_values("", entity1, entity2) + return differences + + +def create_complete_virtual_key_data( + name: str = None, + team_id: str = None, + customer_id: str = None, + include_budget: bool = True, + include_rate_limit: bool = True, +) -> Dict[str, Any]: + """Create complete virtual key data for testing""" + data = { + "name": name or generate_unique_name("Complete VK"), + "description": "Complete test virtual key with all fields", + "allowed_models": ["gpt-4", "claude-3-5-sonnet-20240620"], + "allowed_providers": ["openai", "anthropic"], + "is_active": True, + } + + if team_id: + data["team_id"] = team_id + elif customer_id: + data["customer_id"] = customer_id + + if include_budget: + data["budget"] = { + "max_limit": 50000, # $500.00 in cents + "reset_duration": "1d", + } + + if include_rate_limit: + data["rate_limit"] = { + "token_max_limit": 5000, + "token_reset_duration": "1h", + "request_max_limit": 500, + "request_reset_duration": "1h", + } + + return data + + +def verify_entity_relationships( + entity: Dict[str, Any], expected_relationships: Dict[str, Any] +): + """Verify that entity has expected relationship data loaded""" + for rel_name, expected_data in expected_relationships.items(): + if expected_data is None: + assert entity.get(rel_name) is None, f"Expected {rel_name} to be None" + else: + assert entity.get(rel_name) is not None, f"Expected {rel_name} to be loaded" + if isinstance(expected_data, dict): + for key, value in expected_data.items(): + assert ( + entity[rel_name].get(key) == value + ), f"Expected {rel_name}.{key} to be {value}" + + +def verify_unchanged_fields( + updated_entity: Dict, original_entity: Dict, exclude_fields: List[str] +): + """Verify that all fields except specified ones remain unchanged""" + ignore_fields = ["updated_at", "created_at"] + exclude_fields + + def check_field(path: str, updated_val, original_val): + if path in ignore_fields: + return + + if isinstance(updated_val, dict) and isinstance(original_val, dict): + for key in original_val.keys(): + if key not in ignore_fields: + new_path = f"{path}.{key}" if path else key + if key in updated_val: + check_field(new_path, updated_val[key], original_val[key]) + elif updated_val != original_val: + pytest.fail( + f"Field '{path}' should not have changed. Expected: {original_val}, Got: {updated_val}" + ) + + for field in original_entity.keys(): + if field not in ignore_fields: + if field in updated_entity: + check_field(field, updated_entity[field], original_entity[field]) + + +class FieldUpdateTester: + """Helper class for comprehensive field update testing""" + + def __init__(self, client: GovernanceTestClient, cleanup_tracker: CleanupTracker): + self.client = client + self.cleanup_tracker = cleanup_tracker + + def test_individual_field_updates( + self, entity_type: str, entity_id: str, field_test_cases: List[Dict] + ): + """Test updating individual fields one by one""" + + # Get original entity state + if entity_type == "virtual_key": + original_response = self.client.get_virtual_key(entity_id) + update_func = self.client.update_virtual_key + elif entity_type == "team": + original_response = self.client.get_team(entity_id) + update_func = self.client.update_team + elif entity_type == "customer": + original_response = self.client.get_customer(entity_id) + update_func = self.client.update_customer + else: + raise ValueError(f"Unknown entity type: {entity_type}") + + assert original_response.status_code == 200 + original_entity = original_response.json()[entity_type] + + for test_case in field_test_cases: + # Reset entity to original state if needed + if test_case.get("reset_before", True): + self._reset_entity_state(entity_type, entity_id, original_entity) + + # Perform field update + update_data = test_case["update_data"] + response = update_func(entity_id, update_data) + + # Verify update succeeded + assert ( + response.status_code == 200 + ), f"Field update failed for {test_case['field']}: {response.json()}" + updated_entity = response.json()[entity_type] + + # Verify target field was updated + if test_case.get("custom_validation"): + test_case["custom_validation"](updated_entity) + else: + self._verify_field_updated( + updated_entity, test_case["field"], test_case["expected_value"] + ) + + # Verify other fields unchanged if specified + if test_case.get("verify_unchanged", True): + exclude_fields = test_case.get( + "exclude_from_unchanged_check", [test_case["field"]] + ) + verify_unchanged_fields(updated_entity, original_entity, exclude_fields) + + def _reset_entity_state(self, entity_type: str, entity_id: str, target_state: Dict): + """Reset entity to target state""" + # This would require implementing a reset mechanism + # For now, we'll rely on test isolation + pass + + def _verify_field_updated(self, entity: Dict, field_path: str, expected_value): + """Verify that a field was updated to expected value""" + field_parts = field_path.split(".") + current_value = entity + + for part in field_parts: + if isinstance(current_value, dict): + current_value = current_value.get(part) + else: + pytest.fail(f"Cannot access field '{field_path}' in entity") + + assert ( + current_value == expected_value + ), f"Field '{field_path}' not updated correctly. Expected: {expected_value}, Got: {current_value}" + + +@pytest.fixture +def field_update_tester(governance_client, cleanup_tracker): + """Field update testing helper""" + return FieldUpdateTester(governance_client, cleanup_tracker) diff --git a/tests/governance/pytest.ini b/tests/governance/pytest.ini new file mode 100644 index 000000000..2f6bde148 --- /dev/null +++ b/tests/governance/pytest.ini @@ -0,0 +1,88 @@ +[tool:pytest] +# Pytest configuration for Bifrost Governance Plugin Testing + +# Test discovery +testpaths = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Minimum version +minversion = 7.0 + +# Add options +addopts = + -ra + --strict-markers + --strict-config + --color=yes + --tb=short + --maxfail=10 + --durations=10 + --verbose + +# Markers for test categorization +markers = + governance: Tests for governance functionality + virtual_keys: Virtual Key CRUD and management tests + teams: Team CRUD and management tests + customers: Customer CRUD and management tests + budget: Budget-related tests + rate_limit: Rate limiting tests + usage_tracking: Usage tracking and monitoring tests + crud: CRUD operation tests + field_updates: Comprehensive field update tests + validation: Validation and constraint tests + integration: Integration and end-to-end tests + edge_cases: Edge cases and boundary condition tests + concurrency: Concurrency and race condition tests + mutual_exclusivity: Mutual exclusivity constraint tests + hierarchical: Hierarchical governance tests + slow: Tests that run slowly (> 5 seconds) + smoke: Smoke tests for quick validation + regression: Regression tests + api: API endpoint tests + relationships: Entity relationship tests + cleanup: Tests that require special cleanup + security: Security-related tests + +# Test timeout (in seconds) +timeout = 300 + +# Warnings configuration +filterwarnings = + error + ignore::UserWarning + ignore::DeprecationWarning + ignore::PendingDeprecationWarning + ignore::requests.packages.urllib3.disable_warnings + +# Logging configuration +log_cli = true +log_cli_level = INFO +log_cli_format = %(asctime)s [%(levelname)8s] %(name)s: %(message)s +log_cli_date_format = %Y-%m-%d %H:%M:%S + +log_file = governance_tests.log +log_file_level = DEBUG +log_file_format = %(asctime)s [%(levelname)8s] %(filename)s:%(lineno)d %(funcName)s(): %(message)s +log_file_date_format = %Y-%m-%d %H:%M:%S + +# Coverage configuration (when using --cov) +[coverage:run] +source = . +omit = + */tests/* + */test_* + */__pycache__/* + */venv/* + */env/* + .tox/* + +[coverage:report] +precision = 2 +show_missing = true +skip_covered = false + +[coverage:html] +directory = htmlcov \ No newline at end of file diff --git a/tests/governance/requirements.txt b/tests/governance/requirements.txt new file mode 100644 index 000000000..c25a0301f --- /dev/null +++ b/tests/governance/requirements.txt @@ -0,0 +1,52 @@ +# Bifrost Governance Plugin Test Suite Dependencies + +# Core testing framework +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +pytest-xdist>=3.3.0 # For parallel test execution +pytest-cov>=4.1.0 # For coverage reporting +pytest-html>=3.2.0 # For HTML reports +pytest-json-report>=1.5.0 # For JSON reports +pytest-timeout>=2.1.0 # For test timeouts + +# HTTP client and API testing +requests>=2.31.0 +urllib3>=2.0.0 + +# Concurrency and async support +aiohttp>=3.8.0 + +# Data handling and validation +pydantic>=2.0.0 +jsonschema>=4.18.0 + +# Performance monitoring +psutil>=5.9.0 # For system metrics +memory-profiler>=0.61.0 # For memory profiling + +# Date/time handling +python-dateutil>=2.8.0 + +# Utilities +faker>=19.0.0 # For generating test data +factory-boy>=3.3.0 # For test data factories + +# Development and debugging +ipdb>=0.13.0 # Debugger +rich>=13.0.0 # Rich console output + +# Configuration management +python-dotenv>=1.0.0 # For environment configuration +pyyaml>=6.0 # For YAML configuration files + +# Type checking (development) +mypy>=1.5.0 # Static type checking +types-requests>=2.31.0 # Type stubs for requests + +# Testing utilities +pytest-mock>=3.11.0 # For mocking +pytest-benchmark>=4.0.0 # For benchmarking +freezegun>=1.2.0 # For time mocking + +# Load testing +locust>=2.15.0 # For load testing scenarios \ No newline at end of file diff --git a/tests/governance/test_customers_crud.py b/tests/governance/test_customers_crud.py new file mode 100644 index 000000000..7040b7f1f --- /dev/null +++ b/tests/governance/test_customers_crud.py @@ -0,0 +1,981 @@ +""" +Comprehensive Customer CRUD Tests for Bifrost Governance Plugin + +This module provides exhaustive testing of Customer operations including: +- Complete CRUD lifecycle testing +- Comprehensive field update testing (individual and batch) +- Team relationship management +- Budget management and hierarchies +- Cascading operations +- Edge cases and validation scenarios +- Concurrency and race condition testing +""" + +import pytest +import time +import uuid +from typing import Dict, Any, List +from concurrent.futures import ThreadPoolExecutor +import copy + +from conftest import ( + assert_response_success, + verify_unchanged_fields, + generate_unique_name, + verify_entity_relationships, + deep_compare_entities, +) + + +class TestCustomerBasicCRUD: + """Test basic CRUD operations for Customers""" + + @pytest.mark.customers + @pytest.mark.crud + @pytest.mark.smoke + def test_customer_create_minimal(self, governance_client, cleanup_tracker): + """Test creating customer with minimal required data""" + data = {"name": generate_unique_name("Minimal Customer")} + + response = governance_client.create_customer(data) + assert_response_success(response, 201) + + customer_data = response.json()["customer"] + cleanup_tracker.add_customer(customer_data["id"]) + + # Verify required fields + assert customer_data["name"] == data["name"] + assert customer_data["id"] is not None + assert customer_data["created_at"] is not None + assert customer_data["updated_at"] is not None + + # Verify optional fields are None/empty + assert customer_data["teams"] == [] + assert customer_data["virtual_keys"] is None + + @pytest.mark.customers + @pytest.mark.crud + @pytest.mark.budget + def test_customer_create_with_budget(self, governance_client, cleanup_tracker): + """Test creating customer with budget""" + data = { + "name": generate_unique_name("Budget Customer"), + "budget": { + "max_limit": 500000, # $5000.00 in cents + "reset_duration": "1M", + }, + } + + response = governance_client.create_customer(data) + assert_response_success(response, 201) + + customer_data = response.json()["customer"] + cleanup_tracker.add_customer(customer_data["id"]) + + # Verify budget was created + assert customer_data["budget"] is not None + assert customer_data["budget"]["max_limit"] == 500000 + assert customer_data["budget"]["reset_duration"] == "1M" + assert customer_data["budget"]["current_usage"] == 0 + assert customer_data["budget_id"] is not None + + @pytest.mark.customers + @pytest.mark.crud + def test_customer_list_all(self, governance_client, sample_customer): + """Test listing all customers""" + response = governance_client.list_customers() + assert_response_success(response, 200) + + data = response.json() + assert "customers" in data + assert "count" in data + assert isinstance(data["customers"], list) + assert data["count"] >= 1 + + # Find our test customer + test_customer = next( + ( + customer + for customer in data["customers"] + if customer["id"] == sample_customer["id"] + ), + None, + ) + assert test_customer is not None + + @pytest.mark.customers + @pytest.mark.crud + def test_customer_get_by_id(self, governance_client, sample_customer): + """Test getting customer by ID with relationships loaded""" + response = governance_client.get_customer(sample_customer["id"]) + assert_response_success(response, 200) + + customer_data = response.json()["customer"] + assert customer_data["id"] == sample_customer["id"] + assert customer_data["name"] == sample_customer["name"] + + # Verify teams relationship is loaded (empty list if no teams) + assert "teams" in customer_data + assert ( + isinstance(customer_data["teams"], list) or customer_data["teams"] is None + ) + + @pytest.mark.customers + @pytest.mark.crud + def test_customer_get_nonexistent(self, governance_client): + """Test getting non-existent customer returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.get_customer(fake_id) + assert response.status_code == 404 + + @pytest.mark.customers + @pytest.mark.crud + def test_customer_delete(self, governance_client, cleanup_tracker): + """Test deleting a customer""" + # Create customer to delete + data = {"name": generate_unique_name("Delete Test Customer")} + create_response = governance_client.create_customer(data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + + # Delete customer + delete_response = governance_client.delete_customer(customer_id) + assert_response_success(delete_response, 200) + + # Verify customer is gone + get_response = governance_client.get_customer(customer_id) + assert get_response.status_code == 404 + + @pytest.mark.customers + @pytest.mark.crud + def test_customer_delete_nonexistent(self, governance_client): + """Test deleting non-existent customer returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.delete_customer(fake_id) + assert response.status_code == 404 + + +class TestCustomerValidation: + """Test validation rules for Customer operations""" + + @pytest.mark.customers + @pytest.mark.validation + def test_customer_create_missing_name(self, governance_client): + """Test creating customer without name fails""" + data = {"budget": {"max_limit": 1000, "reset_duration": "1h"}} + response = governance_client.create_customer(data) + assert response.status_code == 400 + + @pytest.mark.customers + @pytest.mark.validation + def test_customer_create_empty_name(self, governance_client): + """Test creating customer with empty name fails""" + data = {"name": ""} + response = governance_client.create_customer(data) + assert response.status_code == 400 + + @pytest.mark.customers + @pytest.mark.validation + def test_customer_create_invalid_budget(self, governance_client): + """Test creating customer with invalid budget data""" + # Test negative budget + data = { + "name": generate_unique_name("Negative Budget Customer"), + "budget": {"max_limit": -10000, "reset_duration": "1h"}, + } + response = governance_client.create_customer(data) + assert response.status_code == 400 + + # Test invalid reset duration + data = { + "name": generate_unique_name("Invalid Duration Customer"), + "budget": {"max_limit": 10000, "reset_duration": "invalid_duration"}, + } + response = governance_client.create_customer(data) + assert response.status_code == 400 + + @pytest.mark.customers + @pytest.mark.validation + def test_customer_create_invalid_json(self, governance_client): + """Test creating customer with invalid data types""" + data = { + "name": 12345, # Should be string + "budget": "not_an_object", # Should be object + } + response = governance_client.create_customer(data) + assert response.status_code == 400 + + +class TestCustomerFieldUpdates: + """Comprehensive tests for Customer field updates""" + + @pytest.mark.customers + @pytest.mark.field_updates + def test_customer_update_individual_fields( + self, governance_client, cleanup_tracker + ): + """Test updating each customer field individually""" + # Create customer with all fields for testing + original_data = { + "name": generate_unique_name("Complete Update Test Customer"), + "budget": {"max_limit": 250000, "reset_duration": "1w"}, + } + create_response = governance_client.create_customer(original_data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + # Get original state + original_response = governance_client.get_customer(customer_id) + original_customer = original_response.json()["customer"] + + # Test individual field updates + field_test_cases = [ + { + "field": "name", + "update_data": {"name": "Updated Customer Name"}, + "expected_value": "Updated Customer Name", + } + ] + + for test_case in field_test_cases: + # Reset customer to original state + reset_data = {"name": original_customer["name"]} + governance_client.update_customer(customer_id, reset_data) + + # Perform field update + response = governance_client.update_customer( + customer_id, test_case["update_data"] + ) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + + # Verify target field was updated + if test_case.get("custom_validation"): + test_case["custom_validation"](updated_customer) + else: + field_parts = test_case["field"].split(".") + current_value = updated_customer + for part in field_parts: + current_value = current_value[part] + assert ( + current_value == test_case["expected_value"] + ), f"Field {test_case['field']} not updated correctly" + + # Verify other fields unchanged (if specified) + if test_case.get("verify_unchanged", True): + exclude_fields = test_case.get( + "exclude_from_unchanged_check", [test_case["field"]] + ) + verify_unchanged_fields( + updated_customer, original_customer, exclude_fields + ) + + @pytest.mark.customers + @pytest.mark.field_updates + @pytest.mark.budget + def test_customer_budget_updates(self, governance_client, cleanup_tracker): + """Test comprehensive budget creation, update, and modification""" + # Create customer without budget + data = {"name": generate_unique_name("Budget Update Test Customer")} + create_response = governance_client.create_customer(data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + # Test 1: Add budget to customer without budget + budget_data = {"max_limit": 100000, "reset_duration": "1M"} + response = governance_client.update_customer( + customer_id, {"budget": budget_data} + ) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + assert updated_customer["budget"]["max_limit"] == 100000 + assert updated_customer["budget"]["reset_duration"] == "1M" + assert updated_customer["budget_id"] is not None + + # Test 2: Update existing budget completely + new_budget_data = {"max_limit": 200000, "reset_duration": "3M"} + response = governance_client.update_customer( + customer_id, {"budget": new_budget_data} + ) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + assert updated_customer["budget"]["max_limit"] == 200000 + assert updated_customer["budget"]["reset_duration"] == "3M" + + # Test 3: Partial budget update (only max_limit) + response = governance_client.update_customer( + customer_id, {"budget": {"max_limit": 300000}} + ) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + assert updated_customer["budget"]["max_limit"] == 300000 + assert ( + updated_customer["budget"]["reset_duration"] == "3M" + ) # Should remain unchanged + + # Test 4: Partial budget update (only reset_duration) + response = governance_client.update_customer( + customer_id, {"budget": {"reset_duration": "6M"}} + ) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + assert ( + updated_customer["budget"]["max_limit"] == 300000 + ) # Should remain unchanged + assert updated_customer["budget"]["reset_duration"] == "6M" + + @pytest.mark.customers + @pytest.mark.field_updates + def test_customer_multiple_field_updates(self, governance_client, cleanup_tracker): + """Test updating multiple fields simultaneously""" + # Create customer with initial data + initial_data = { + "name": generate_unique_name("Multi-Field Test Customer"), + } + create_response = governance_client.create_customer(initial_data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + # Update multiple fields at once + update_data = { + "name": "Updated Multi-Field Customer Name", + "budget": {"max_limit": 500000, "reset_duration": "1Y"}, + } + + response = governance_client.update_customer(customer_id, update_data) + assert_response_success(response, 200) + + updated_customer = response.json()["customer"] + assert updated_customer["name"] == "Updated Multi-Field Customer Name" + assert updated_customer["budget"]["max_limit"] == 500000 + assert updated_customer["budget"]["reset_duration"] == "1Y" + + @pytest.mark.customers + @pytest.mark.field_updates + @pytest.mark.edge_cases + def test_customer_update_edge_cases(self, governance_client, cleanup_tracker): + """Test edge cases in customer updates""" + # Create test customer + data = {"name": generate_unique_name("Edge Case Customer")} + create_response = governance_client.create_customer(data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + original_response = governance_client.get_customer(customer_id) + original_customer = original_response.json()["customer"] + + # Test 1: Empty update (should return unchanged customer) + response = governance_client.update_customer(customer_id, {}) + assert_response_success(response, 200) + updated_customer = response.json()["customer"] + + # Compare ignoring timestamps + differences = deep_compare_entities( + updated_customer, original_customer, ignore_fields=["updated_at"] + ) + assert len(differences) == 0, f"Empty update changed fields: {differences}" + + # Test 2: Update with same values + response = governance_client.update_customer( + customer_id, {"name": original_customer["name"]} + ) + assert_response_success(response, 200) + + # Test 3: Very long customer name (test field length limits) + long_name = "x" * 1000 # Adjust based on actual field limits + response = governance_client.update_customer(customer_id, {"name": long_name}) + # Expected behavior depends on API validation rules + + @pytest.mark.customers + @pytest.mark.field_updates + def test_customer_update_nonexistent(self, governance_client): + """Test updating non-existent customer returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.update_customer(fake_id, {"name": "test"}) + assert response.status_code == 404 + + +class TestCustomerBudgetManagement: + """Test customer budget specific functionality""" + + @pytest.mark.customers + @pytest.mark.budget + def test_customer_budget_creation_and_validation( + self, governance_client, cleanup_tracker + ): + """Test budget creation with various configurations""" + # Test valid budget configurations + budget_test_cases = [ + {"max_limit": 50000, "reset_duration": "1d"}, + {"max_limit": 250000, "reset_duration": "1w"}, + {"max_limit": 1000000, "reset_duration": "1M"}, + {"max_limit": 5000000, "reset_duration": "3M"}, + {"max_limit": 10000000, "reset_duration": "1Y"}, + ] + + for budget_config in budget_test_cases: + data = { + "name": generate_unique_name( + f"Budget Customer {budget_config['reset_duration']}" + ), + "budget": budget_config, + } + + response = governance_client.create_customer(data) + assert_response_success(response, 201) + + customer_data = response.json()["customer"] + cleanup_tracker.add_customer(customer_data["id"]) + + assert customer_data["budget"]["max_limit"] == budget_config["max_limit"] + assert ( + customer_data["budget"]["reset_duration"] + == budget_config["reset_duration"] + ) + assert customer_data["budget"]["current_usage"] == 0 + assert customer_data["budget"]["last_reset"] is not None + + @pytest.mark.customers + @pytest.mark.budget + @pytest.mark.edge_cases + def test_customer_budget_edge_cases(self, governance_client, cleanup_tracker): + """Test budget edge cases and boundary conditions""" + # Test boundary values + edge_case_budgets = [ + {"max_limit": 0, "reset_duration": "1h"}, # Zero budget + {"max_limit": 1, "reset_duration": "1s"}, # Minimal values + {"max_limit": 9223372036854775807, "reset_duration": "1h"}, # Max int64 + ] + + for budget_config in edge_case_budgets: + data = { + "name": generate_unique_name( + f"Edge Budget Customer {budget_config['max_limit']}" + ), + "budget": budget_config, + } + + response = governance_client.create_customer(data) + # Adjust assertions based on API validation rules + if ( + budget_config["max_limit"] >= 0 + ): # Assuming non-negative budgets are valid + assert_response_success(response, 201) + cleanup_tracker.add_customer(response.json()["customer"]["id"]) + else: + assert response.status_code == 400 + + @pytest.mark.customers + @pytest.mark.budget + @pytest.mark.hierarchical + def test_customer_budget_hierarchy_foundation( + self, governance_client, cleanup_tracker + ): + """Test customer budget as foundation of hierarchical budget system""" + # Create customer with large budget (top of hierarchy) + customer_data = { + "name": generate_unique_name("Hierarchy Foundation Customer"), + "budget": {"max_limit": 1000000, "reset_duration": "1M"}, # $10,000 + } + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create teams under this customer with smaller budgets + team1_data = { + "name": generate_unique_name("Sub-Team 1"), + "customer_id": customer["id"], + "budget": {"max_limit": 300000, "reset_duration": "1M"}, # $3,000 + } + team1_response = governance_client.create_team(team1_data) + assert_response_success(team1_response, 201) + team1 = team1_response.json()["team"] + cleanup_tracker.add_team(team1["id"]) + + team2_data = { + "name": generate_unique_name("Sub-Team 2"), + "customer_id": customer["id"], + "budget": {"max_limit": 200000, "reset_duration": "1M"}, # $2,000 + } + team2_response = governance_client.create_team(team2_data) + assert_response_success(team2_response, 201) + team2 = team2_response.json()["team"] + cleanup_tracker.add_team(team2["id"]) + + # Create VKs under teams with even smaller budgets + vk1_data = { + "name": generate_unique_name("Team1 VK"), + "team_id": team1["id"], + "budget": {"max_limit": 100000, "reset_duration": "1M"}, # $1,000 + } + vk1_response = governance_client.create_virtual_key(vk1_data) + assert_response_success(vk1_response, 201) + vk1 = vk1_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk1["id"]) + + # Verify hierarchy structure + assert customer["budget"]["max_limit"] == 1000000 + assert team1["budget"]["max_limit"] == 300000 + assert team2["budget"]["max_limit"] == 200000 + assert vk1["budget"]["max_limit"] == 100000 + + # Verify relationships + assert team1["customer_id"] == customer["id"] + assert team2["customer_id"] == customer["id"] + assert vk1["team_id"] == team1["id"] + + @pytest.mark.customers + @pytest.mark.budget + def test_customer_budget_large_scale(self, governance_client, cleanup_tracker): + """Test customer budgets for large enterprise scenarios""" + # Test very large budget for enterprise customer + enterprise_data = { + "name": generate_unique_name("Enterprise Customer"), + "budget": { + "max_limit": 100000000000, # $1 billion in cents + "reset_duration": "1Y", + }, + } + + response = governance_client.create_customer(enterprise_data) + assert_response_success(response, 201) + customer = response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + assert customer["budget"]["max_limit"] == 100000000000 + assert customer["budget"]["reset_duration"] == "1Y" + + +class TestCustomerTeamRelationships: + """Test customer relationships with teams""" + + @pytest.mark.customers + @pytest.mark.relationships + def test_customer_teams_relationship_loading( + self, governance_client, cleanup_tracker + ): + """Test that customer properly loads teams relationships""" + # Create customer + customer_data = {"name": generate_unique_name("Team Parent Customer")} + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create teams under this customer + team_names = [] + for i in range(3): + team_name = generate_unique_name(f"Customer Team {i}") + team_names.append(team_name) + team_data = {"name": team_name, "customer_id": customer["id"]} + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + cleanup_tracker.add_team(team_response.json()["team"]["id"]) + + # Fetch customer with teams loaded + customer_response = governance_client.get_customer(customer["id"]) + assert_response_success(customer_response, 200) + customer_with_teams = customer_response.json()["customer"] + + # Verify teams relationship loaded + assert "teams" in customer_with_teams + teams = customer_with_teams["teams"] + assert isinstance(teams, list) + assert len(teams) == 3 + + # Verify all team names are present + loaded_team_names = {team["name"] for team in teams} + for name in team_names: + assert name in loaded_team_names + + # Verify all teams have correct customer_id + for team in teams: + assert team["customer_id"] == customer["id"] + + @pytest.mark.customers + @pytest.mark.relationships + def test_customer_with_no_teams(self, governance_client, cleanup_tracker): + """Test customer with no teams has empty teams list""" + # Create customer without teams + customer_data = {"name": generate_unique_name("No Teams Customer")} + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Fetch customer with teams loaded + customer_response = governance_client.get_customer(customer["id"]) + assert_response_success(customer_response, 200) + customer_data = customer_response.json()["customer"] + + # Teams should be empty list or None + teams = customer_data.get("teams") + assert teams == [] or teams is None + + @pytest.mark.customers + @pytest.mark.relationships + def test_customer_teams_cascading_operations( + self, governance_client, cleanup_tracker + ): + """Test cascading operations between customers and teams""" + # Create customer + customer_data = {"name": generate_unique_name("Cascade Test Customer")} + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create teams under customer + team_ids = [] + for i in range(2): + team_data = { + "name": generate_unique_name(f"Cascade Team {i}"), + "customer_id": customer["id"], + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team_id = team_response.json()["team"]["id"] + team_ids.append(team_id) + cleanup_tracker.add_team(team_id) + + # Create VKs under teams + vk_ids = [] + for team_id in team_ids: + vk_data = {"name": generate_unique_name("Cascade VK"), "team_id": team_id} + vk_response = governance_client.create_virtual_key(vk_data) + assert_response_success(vk_response, 201) + vk_id = vk_response.json()["virtual_key"]["id"] + vk_ids.append(vk_id) + cleanup_tracker.add_virtual_key(vk_id) + + # Verify all entities exist and are properly linked + customer_response = governance_client.get_customer(customer["id"]) + customer_with_teams = customer_response.json()["customer"] + assert len(customer_with_teams["teams"]) == 2 + + for vk_id in vk_ids: + vk_response = governance_client.get_virtual_key(vk_id) + vk = vk_response.json()["virtual_key"] + assert vk["team"] is not None + assert vk["team"]["customer_id"] == customer["id"] + + @pytest.mark.customers + @pytest.mark.relationships + @pytest.mark.edge_cases + def test_customer_orphaned_teams_handling(self, governance_client, cleanup_tracker): + """Test customer behavior when teams reference non-existent customer""" + # This test simulates data integrity issues + # In practice, this would be prevented by foreign key constraints + + # Create customer and team + customer_data = {"name": generate_unique_name("Temp Customer")} + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + team_data = { + "name": generate_unique_name("Orphan Test Team"), + "customer_id": customer["id"], + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team = team_response.json()["team"] + cleanup_tracker.add_team(team["id"]) + + # If we were to delete the customer, what happens to the team? + # This depends on database constraints and API implementation + # For now, we just verify the relationship exists correctly + assert team["customer_id"] == customer["id"] + assert team["customer"]["id"] == customer["id"] + + +class TestCustomerConcurrency: + """Test concurrent operations on Customers""" + + @pytest.mark.customers + @pytest.mark.concurrency + @pytest.mark.slow + def test_customer_concurrent_creation(self, governance_client, cleanup_tracker): + """Test creating multiple customers concurrently""" + + def create_customer(index): + data = {"name": generate_unique_name(f"Concurrent Customer {index}")} + response = governance_client.create_customer(data) + return response + + # Create 10 customers concurrently + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(create_customer, i) for i in range(10)] + responses = [future.result() for future in futures] + + # Verify all succeeded + created_customers = [] + for response in responses: + assert_response_success(response, 201) + customer_data = response.json()["customer"] + created_customers.append(customer_data) + cleanup_tracker.add_customer(customer_data["id"]) + + # Verify all customers have unique IDs + customer_ids = [customer["id"] for customer in created_customers] + assert len(set(customer_ids)) == 10 # All unique IDs + + @pytest.mark.customers + @pytest.mark.concurrency + @pytest.mark.slow + def test_customer_concurrent_updates(self, governance_client, cleanup_tracker): + """Test updating same customer concurrently""" + # Create customer to update + data = {"name": generate_unique_name("Concurrent Update Customer")} + create_response = governance_client.create_customer(data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + # Update concurrently with different names + def update_customer(index): + update_data = {"name": f"Updated by thread {index}"} + response = governance_client.update_customer(customer_id, update_data) + return response, index + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_customer, i) for i in range(5)] + results = [future.result() for future in futures] + + # All updates should succeed (last one wins) + for response, index in results: + assert_response_success(response, 200) + + # Verify final state + final_response = governance_client.get_customer(customer_id) + final_customer = final_response.json()["customer"] + assert final_customer["name"].startswith("Updated by thread") + + @pytest.mark.customers + @pytest.mark.concurrency + @pytest.mark.slow + def test_customer_concurrent_budget_updates( + self, governance_client, cleanup_tracker + ): + """Test concurrent budget updates on same customer""" + # Create customer with budget + data = { + "name": generate_unique_name("Concurrent Budget Customer"), + "budget": {"max_limit": 100000, "reset_duration": "1d"}, + } + create_response = governance_client.create_customer(data) + assert_response_success(create_response, 201) + customer_id = create_response.json()["customer"]["id"] + cleanup_tracker.add_customer(customer_id) + + # Update budget concurrently with different limits + def update_budget(index): + limit = 100000 + (index * 10000) # Different limits + update_data = {"budget": {"max_limit": limit}} + response = governance_client.update_customer(customer_id, update_data) + return response, limit + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_budget, i) for i in range(5)] + results = [future.result() for future in futures] + + # All updates should succeed + for response, limit in results: + assert_response_success(response, 200) + + # Verify final state has one of the updated limits + final_response = governance_client.get_customer(customer_id) + final_customer = final_response.json()["customer"] + final_limit = final_customer["budget"]["max_limit"] + expected_limits = [100000 + (i * 10000) for i in range(5)] + assert final_limit in expected_limits + + +class TestCustomerComplexScenarios: + """Test complex scenarios involving customers""" + + @pytest.mark.customers + @pytest.mark.hierarchical + @pytest.mark.slow + def test_customer_large_hierarchy_creation( + self, governance_client, cleanup_tracker + ): + """Test creating large hierarchical structure under customer""" + # Create customer + customer_data = { + "name": generate_unique_name("Large Hierarchy Customer"), + "budget": {"max_limit": 10000000, "reset_duration": "1M"}, # $100,000 + } + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create multiple teams + team_ids = [] + for i in range(5): + team_data = { + "name": generate_unique_name(f"Large Hierarchy Team {i}"), + "customer_id": customer["id"], + "budget": { + "max_limit": 1000000, + "reset_duration": "1M", + }, # $10,000 each + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team_id = team_response.json()["team"]["id"] + team_ids.append(team_id) + cleanup_tracker.add_team(team_id) + + # Create multiple VKs per team + vk_count = 0 + for team_id in team_ids: + for j in range(3): # 3 VKs per team + vk_data = { + "name": generate_unique_name(f"Large Hierarchy VK {team_id}-{j}"), + "team_id": team_id, + "budget": { + "max_limit": 100000, + "reset_duration": "1M", + }, # $1,000 each + } + vk_response = governance_client.create_virtual_key(vk_data) + assert_response_success(vk_response, 201) + vk_id = vk_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + vk_count += 1 + + # Verify hierarchy structure + customer_response = governance_client.get_customer(customer["id"]) + customer_with_teams = customer_response.json()["customer"] + + assert len(customer_with_teams["teams"]) == 5 + assert vk_count == 15 # 5 teams * 3 VKs each + + # Verify budget hierarchy makes sense + total_team_budgets = sum( + team.get("budget", {}).get("max_limit", 0) + for team in customer_with_teams["teams"] + ) + assert ( + total_team_budgets <= customer["budget"]["max_limit"] + ) # Teams shouldn't exceed customer + + @pytest.mark.customers + @pytest.mark.performance + @pytest.mark.slow + def test_customer_performance_with_many_teams( + self, governance_client, cleanup_tracker + ): + """Test customer performance when loading many teams""" + # Create customer + customer_data = {"name": generate_unique_name("Performance Test Customer")} + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create many teams + team_count = 50 # Adjust based on performance requirements + start_time = time.time() + + for i in range(team_count): + team_data = { + "name": generate_unique_name(f"Perf Team {i}"), + "customer_id": customer["id"], + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + cleanup_tracker.add_team(team_response.json()["team"]["id"]) + + creation_time = time.time() - start_time + + # Test customer loading performance + start_time = time.time() + customer_response = governance_client.get_customer(customer["id"]) + assert_response_success(customer_response, 200) + load_time = time.time() - start_time + + customer_with_teams = customer_response.json()["customer"] + assert len(customer_with_teams["teams"]) == team_count + + # Log performance metrics (adjust thresholds based on requirements) + print(f"Created {team_count} teams in {creation_time:.2f}s") + print(f"Loaded customer with {team_count} teams in {load_time:.2f}s") + + # Performance assertions (adjust based on requirements) + assert ( + load_time < 5.0 + ), f"Loading customer with {team_count} teams took too long: {load_time}s" + + @pytest.mark.customers + @pytest.mark.integration + def test_customer_full_lifecycle_scenario(self, governance_client, cleanup_tracker): + """Test complete customer lifecycle scenario""" + # 1. Create customer with budget + customer_data = { + "name": generate_unique_name("Lifecycle Customer"), + "budget": {"max_limit": 1000000, "reset_duration": "1M"}, + } + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # 2. Update customer name and budget + update_data = { + "name": "Updated Lifecycle Customer", + "budget": {"max_limit": 2000000, "reset_duration": "3M"}, + } + update_response = governance_client.update_customer(customer["id"], update_data) + assert_response_success(update_response, 200) + updated_customer = update_response.json()["customer"] + assert updated_customer["name"] == "Updated Lifecycle Customer" + assert updated_customer["budget"]["max_limit"] == 2000000 + + # 3. Create teams under customer + team_data = { + "name": generate_unique_name("Lifecycle Team"), + "customer_id": customer["id"], + "budget": {"max_limit": 500000, "reset_duration": "1M"}, + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team = team_response.json()["team"] + cleanup_tracker.add_team(team["id"]) + + # 4. Create VKs under team + vk_data = { + "name": generate_unique_name("Lifecycle VK"), + "team_id": team["id"], + "budget": {"max_limit": 100000, "reset_duration": "1d"}, + } + vk_response = governance_client.create_virtual_key(vk_data) + assert_response_success(vk_response, 201) + vk = vk_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + # 5. Verify complete hierarchy + final_customer_response = governance_client.get_customer(customer["id"]) + final_customer = final_customer_response.json()["customer"] + + assert final_customer["name"] == "Updated Lifecycle Customer" + assert len(final_customer["teams"]) == 1 + assert final_customer["teams"][0]["id"] == team["id"] + + final_vk_response = governance_client.get_virtual_key(vk["id"]) + final_vk = final_vk_response.json()["virtual_key"] + + # Verify VK belongs to team (customer relationship not preloaded in VK->team) + assert final_vk["team"]["id"] == team["id"] + assert final_vk["team"].get("customer_id") == customer["id"] + + # 6. Clean up (automatic via cleanup_tracker) + # This tests the full CRUD lifecycle diff --git a/tests/governance/test_helpers.py b/tests/governance/test_helpers.py new file mode 100644 index 000000000..605f8f398 --- /dev/null +++ b/tests/governance/test_helpers.py @@ -0,0 +1,644 @@ +""" +Helper utilities and test data generators for Bifrost Governance Plugin tests. + +This module provides additional utilities for test data generation, validation, +and common test operations to support the comprehensive governance test suite. +""" + +import pytest +import uuid +import time +import json +import random +from typing import Dict, Any, List, Optional, Union +from datetime import datetime, timedelta +from faker import Faker + +from conftest import assert_response_success, generate_unique_name, GovernanceTestClient + +# Initialize Faker for generating test data +fake = Faker() + + +class TestDataFactory: + """Factory for generating realistic test data""" + + @staticmethod + def generate_budget_config( + min_limit: int = 1000, + max_limit: int = 1000000, + duration_options: List[str] = None, + ) -> Dict[str, Any]: + """Generate realistic budget configuration""" + if duration_options is None: + duration_options = ["1h", "1d", "1w", "1M", "3M", "6M", "1Y"] + + return { + "max_limit": random.randint(min_limit, max_limit), + "reset_duration": random.choice(duration_options), + } + + @staticmethod + def generate_rate_limit_config( + include_tokens: bool = True, include_requests: bool = True + ) -> Dict[str, Any]: + """Generate realistic rate limit configuration""" + config = {} + + if include_tokens: + config.update( + { + "token_max_limit": random.randint(100, 100000), + "token_reset_duration": random.choice(["1m", "5m", "1h", "1d"]), + } + ) + + if include_requests: + config.update( + { + "request_max_limit": random.randint(10, 10000), + "request_reset_duration": random.choice(["1m", "5m", "1h", "1d"]), + } + ) + + return config + + @staticmethod + def generate_customer_data(include_budget: bool = False) -> Dict[str, Any]: + """Generate realistic customer data""" + data = {"name": f"{fake.company()} ({generate_unique_name('Customer')})"} + + if include_budget: + data["budget"] = TestDataFactory.generate_budget_config( + min_limit=100000, max_limit=10000000 # Customers have larger budgets + ) + + return data + + @staticmethod + def generate_team_data( + customer_id: Optional[str] = None, include_budget: bool = False + ) -> Dict[str, Any]: + """Generate realistic team data""" + team_types = [ + "Engineering", + "Marketing", + "Sales", + "Research", + "Support", + "Operations", + ] + data = { + "name": f"{random.choice(team_types)} Team ({generate_unique_name('Team')})" + } + + if customer_id: + data["customer_id"] = customer_id + + if include_budget: + data["budget"] = TestDataFactory.generate_budget_config( + min_limit=10000, max_limit=1000000 # Teams have medium budgets + ) + + return data + + @staticmethod + def generate_virtual_key_data( + team_id: Optional[str] = None, + customer_id: Optional[str] = None, + include_budget: bool = False, + include_rate_limit: bool = False, + model_restrictions: bool = False, + ) -> Dict[str, Any]: + """Generate realistic virtual key data""" + purposes = [ + "Development", + "Production", + "Testing", + "Staging", + "Demo", + "Research", + ] + data = { + "name": f"{random.choice(purposes)} VK ({generate_unique_name('VK')})", + "description": fake.sentence(), + "is_active": random.choice([True, True, True, False]), # 75% active + } + + if team_id: + data["team_id"] = team_id + elif customer_id: + data["customer_id"] = customer_id + + if model_restrictions: + all_models = [ + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "claude-3-5-sonnet-20240620", + "claude-3-7-sonnet-20250219", + ] + all_providers = ["openai", "anthropic"] + + data["allowed_models"] = random.sample( + all_models, random.randint(1, len(all_models)) + ) + data["allowed_providers"] = random.sample( + all_providers, random.randint(1, len(all_providers)) + ) + + if include_budget: + data["budget"] = TestDataFactory.generate_budget_config( + min_limit=1000, max_limit=100000 # VKs have smaller budgets + ) + + if include_rate_limit: + data["rate_limit"] = TestDataFactory.generate_rate_limit_config() + + return data + + +class ValidationHelper: + """Helper functions for validating test results""" + + @staticmethod + def validate_entity_structure( + entity: Dict[str, Any], entity_type: str + ) -> List[str]: + """Validate that entity has expected structure""" + errors = [] + + # Common fields all entities should have + required_fields = ["id", "created_at", "updated_at"] + for field in required_fields: + if field not in entity: + errors.append(f"Missing required field: {field}") + elif entity[field] is None: + errors.append(f"Required field is None: {field}") + + # Entity-specific validation + if entity_type == "virtual_key": + vk_fields = ["name", "value", "is_active"] + for field in vk_fields: + if field not in entity: + errors.append(f"VK missing field: {field}") + + elif entity_type == "team": + team_fields = ["name"] + for field in team_fields: + if field not in entity: + errors.append(f"Team missing field: {field}") + + elif entity_type == "customer": + customer_fields = ["name"] + for field in customer_fields: + if field not in entity: + errors.append(f"Customer missing field: {field}") + + return errors + + @staticmethod + def validate_budget_structure(budget: Dict[str, Any]) -> List[str]: + """Validate budget structure""" + errors = [] + required_fields = [ + "id", + "max_limit", + "reset_duration", + "current_usage", + "last_reset", + ] + + for field in required_fields: + if field not in budget: + errors.append(f"Budget missing field: {field}") + + if budget.get("max_limit") is not None and budget["max_limit"] < 0: + errors.append("Budget max_limit cannot be negative") + + if budget.get("current_usage") is not None and budget["current_usage"] < 0: + errors.append("Budget current_usage cannot be negative") + + return errors + + @staticmethod + def validate_rate_limit_structure(rate_limit: Dict[str, Any]) -> List[str]: + """Validate rate limit structure""" + errors = [] + required_fields = ["id"] + + for field in required_fields: + if field not in rate_limit: + errors.append(f"Rate limit missing field: {field}") + + # At least one limit should be specified + token_fields = ["token_max_limit", "token_reset_duration"] + request_fields = ["request_max_limit", "request_reset_duration"] + + has_token_limits = any( + rate_limit.get(field) is not None for field in token_fields + ) + has_request_limits = any( + rate_limit.get(field) is not None for field in request_fields + ) + + if not has_token_limits and not has_request_limits: + errors.append("Rate limit must have either token or request limits") + + return errors + + @staticmethod + def validate_hierarchy_consistency( + customer: Dict, teams: List[Dict], vks: List[Dict] + ) -> List[str]: + """Validate hierarchical consistency""" + errors = [] + + # Check team customer references + for team in teams: + if team.get("customer_id") != customer["id"]: + errors.append(f"Team {team['id']} has incorrect customer_id") + + # Check VK team references + team_ids = {team["id"] for team in teams} + for vk in vks: + if vk.get("team_id") and vk["team_id"] not in team_ids: + errors.append(f"VK {vk['id']} references non-existent team") + + return errors + + +class TestScenarioBuilder: + """Builder for complex test scenarios""" + + def __init__(self, client: GovernanceTestClient, cleanup_tracker): + self.client = client + self.cleanup_tracker = cleanup_tracker + self.created_entities = {"customers": [], "teams": [], "virtual_keys": []} + + def create_customer(self, **kwargs) -> Dict[str, Any]: + """Create a customer with automatic cleanup tracking""" + data = TestDataFactory.generate_customer_data(**kwargs) + response = self.client.create_customer(data) + assert_response_success(response, 201) + + customer = response.json()["customer"] + self.cleanup_tracker.add_customer(customer["id"]) + self.created_entities["customers"].append(customer) + return customer + + def create_team( + self, customer_id: Optional[str] = None, **kwargs + ) -> Dict[str, Any]: + """Create a team with automatic cleanup tracking""" + data = TestDataFactory.generate_team_data(customer_id=customer_id, **kwargs) + response = self.client.create_team(data) + assert_response_success(response, 201) + + team = response.json()["team"] + self.cleanup_tracker.add_team(team["id"]) + self.created_entities["teams"].append(team) + return team + + def create_virtual_key( + self, team_id: Optional[str] = None, customer_id: Optional[str] = None, **kwargs + ) -> Dict[str, Any]: + """Create a virtual key with automatic cleanup tracking""" + data = TestDataFactory.generate_virtual_key_data( + team_id=team_id, customer_id=customer_id, **kwargs + ) + response = self.client.create_virtual_key(data) + assert_response_success(response, 201) + + vk = response.json()["virtual_key"] + self.cleanup_tracker.add_virtual_key(vk["id"]) + self.created_entities["virtual_keys"].append(vk) + return vk + + def create_simple_hierarchy(self) -> Dict[str, Any]: + """Create a simple Customer -> Team -> VK hierarchy""" + customer = self.create_customer(include_budget=True) + team = self.create_team(customer_id=customer["id"], include_budget=True) + vk = self.create_virtual_key( + team_id=team["id"], include_budget=True, include_rate_limit=True + ) + + return {"customer": customer, "team": team, "virtual_key": vk} + + def create_complex_hierarchy( + self, team_count: int = 3, vk_per_team: int = 2 + ) -> Dict[str, Any]: + """Create a complex hierarchy with multiple teams and VKs""" + customer = self.create_customer(include_budget=True) + + teams = [] + for i in range(team_count): + team = self.create_team(customer_id=customer["id"], include_budget=True) + teams.append(team) + + vks = [] + for team in teams: + for j in range(vk_per_team): + vk = self.create_virtual_key( + team_id=team["id"], + include_budget=True, + include_rate_limit=True, + model_restrictions=random.choice([True, False]), + ) + vks.append(vk) + + return {"customer": customer, "teams": teams, "virtual_keys": vks} + + def create_mixed_vk_associations(self) -> Dict[str, Any]: + """Create VKs with mixed team/customer associations""" + customer = self.create_customer(include_budget=True) + team = self.create_team(customer_id=customer["id"], include_budget=True) + + # VK directly associated with customer + customer_vk = self.create_virtual_key( + customer_id=customer["id"], include_budget=True + ) + + # VK associated with team (indirect customer association) + team_vk = self.create_virtual_key(team_id=team["id"], include_budget=True) + + # Standalone VK + standalone_vk = self.create_virtual_key( + include_budget=True, include_rate_limit=True + ) + + return { + "customer": customer, + "team": team, + "customer_vk": customer_vk, + "team_vk": team_vk, + "standalone_vk": standalone_vk, + } + + +class PerformanceTracker: + """Track performance metrics during tests""" + + def __init__(self): + self.measurements = [] + + def time_operation(self, operation_name: str, operation_func, *args, **kwargs): + """Time an operation and record the measurement""" + start_time = time.time() + try: + result = operation_func(*args, **kwargs) + success = True + error = None + except Exception as e: + result = None + success = False + error = str(e) + + end_time = time.time() + duration = end_time - start_time + + measurement = { + "operation": operation_name, + "duration": duration, + "success": success, + "error": error, + "timestamp": datetime.now().isoformat(), + } + + self.measurements.append(measurement) + return result, measurement + + def get_stats(self) -> Dict[str, Any]: + """Get performance statistics""" + if not self.measurements: + return {"count": 0} + + durations = [m["duration"] for m in self.measurements] + successes = [m for m in self.measurements if m["success"]] + failures = [m for m in self.measurements if not m["success"]] + + return { + "count": len(self.measurements), + "success_count": len(successes), + "failure_count": len(failures), + "success_rate": len(successes) / len(self.measurements), + "avg_duration": sum(durations) / len(durations), + "min_duration": min(durations), + "max_duration": max(durations), + "total_duration": sum(durations), + } + + def print_report(self): + """Print performance report""" + stats = self.get_stats() + if stats["count"] == 0: + print("No measurements recorded") + return + + print(f"\nPerformance Report:") + print(f" Total operations: {stats['count']}") + print(f" Success rate: {stats['success_rate']:.2%}") + print(f" Average duration: {stats['avg_duration']:.3f}s") + print(f" Min duration: {stats['min_duration']:.3f}s") + print(f" Max duration: {stats['max_duration']:.3f}s") + print(f" Total duration: {stats['total_duration']:.3f}s") + + +class ChatCompletionHelper: + """Helper for chat completion testing""" + + @staticmethod + def generate_test_messages( + complexity: str = "simple", token_count_estimate: int = None + ) -> List[Dict[str, str]]: + """Generate test messages of varying complexity""" + if complexity == "simple": + return [{"role": "user", "content": "Hello, how are you?"}] + + elif complexity == "medium": + return [ + {"role": "user", "content": "Can you explain quantum computing?"}, + { + "role": "assistant", + "content": "Quantum computing is a type of computation that harnesses quantum mechanics...", + }, + { + "role": "user", + "content": "How does it differ from classical computing?", + }, + ] + + elif complexity == "complex": + content = fake.text(max_nb_chars=2000) + return [ + {"role": "system", "content": "You are a helpful AI assistant."}, + {"role": "user", "content": content}, + { + "role": "assistant", + "content": "I understand. Let me help you with that.", + }, + {"role": "user", "content": "Please provide a detailed analysis."}, + ] + + elif complexity == "custom" and token_count_estimate: + # Rough estimate: 4 characters per token + char_count = token_count_estimate * 4 + content = fake.text(max_nb_chars=char_count) + return [{"role": "user", "content": content}] + + else: + return [{"role": "user", "content": fake.sentence()}] + + @staticmethod + def make_test_request( + client: GovernanceTestClient, + vk_value: str, + model: str = "gpt-3.5-turbo", + max_tokens: int = 50, + **kwargs, + ) -> Dict[str, Any]: + """Make a standardized test chat completion request""" + messages = ( + kwargs.get("messages") or ChatCompletionHelper.generate_test_messages() + ) + headers = {"x-bf-vk": vk_value} + + response = client.chat_completion( + messages=messages, + model=model, + headers=headers, + max_tokens=max_tokens, + **{k: v for k, v in kwargs.items() if k != "messages"}, + ) + + return { + "response": response, + "status_code": response.status_code, + "success": response.status_code == 200, + "rate_limited": response.status_code == 429, + "budget_exceeded": response.status_code == 402, + "unauthorized": response.status_code in [401, 403], + "data": ( + response.json() + if response.headers.get("content-type", "").startswith( + "application/json" + ) + else response.text + ), + } + + +# Pytest fixtures for helpers + + +@pytest.fixture +def test_data_factory(): + """Test data factory fixture""" + return TestDataFactory() + + +@pytest.fixture +def validation_helper(): + """Validation helper fixture""" + return ValidationHelper() + + +@pytest.fixture +def scenario_builder(governance_client, cleanup_tracker): + """Test scenario builder fixture""" + return TestScenarioBuilder(governance_client, cleanup_tracker) + + +@pytest.fixture +def performance_tracker(): + """Performance tracker fixture""" + return PerformanceTracker() + + +@pytest.fixture +def chat_completion_helper(): + """Chat completion helper fixture""" + return ChatCompletionHelper() + + +# Test helper usage examples +class TestHelperExamples: + """Examples of how to use the test helpers""" + + @pytest.mark.helpers + def test_data_factory_usage( + self, test_data_factory, governance_client, cleanup_tracker + ): + """Example of using TestDataFactory""" + # Generate and create customer + customer_data = test_data_factory.generate_customer_data(include_budget=True) + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Verify data structure + assert customer["name"].endswith("Customer") + assert customer["budget"] is not None + + @pytest.mark.helpers + def test_scenario_builder_usage(self, scenario_builder): + """Example of using TestScenarioBuilder""" + # Create simple hierarchy + hierarchy = scenario_builder.create_simple_hierarchy() + + # Verify hierarchy structure + assert hierarchy["customer"]["id"] is not None + assert hierarchy["team"]["customer_id"] == hierarchy["customer"]["id"] + assert hierarchy["virtual_key"]["team_id"] == hierarchy["team"]["id"] + + @pytest.mark.helpers + def test_validation_helper_usage(self, validation_helper, sample_virtual_key): + """Example of using ValidationHelper""" + # Validate VK structure + errors = validation_helper.validate_entity_structure( + sample_virtual_key, "virtual_key" + ) + assert len(errors) == 0, f"VK validation errors: {errors}" + + # Validate budget if present + if sample_virtual_key.get("budget"): + budget_errors = validation_helper.validate_budget_structure( + sample_virtual_key["budget"] + ) + assert len(budget_errors) == 0, f"Budget validation errors: {budget_errors}" + + @pytest.mark.helpers + def test_performance_tracker_usage(self, performance_tracker, governance_client): + """Example of using PerformanceTracker""" + # Time an operation + result, measurement = performance_tracker.time_operation( + "list_customers", governance_client.list_customers + ) + + assert measurement["success"] is True + assert measurement["duration"] > 0 + + # Get performance stats + stats = performance_tracker.get_stats() + assert stats["count"] == 1 + assert stats["success_rate"] == 1.0 + + @pytest.mark.helpers + def test_chat_completion_helper_usage( + self, chat_completion_helper, governance_client, sample_virtual_key + ): + """Example of using ChatCompletionHelper""" + # Generate test messages + simple_messages = chat_completion_helper.generate_test_messages("simple") + assert len(simple_messages) == 1 + assert simple_messages[0]["role"] == "user" + + # Make test request + result = chat_completion_helper.make_test_request( + governance_client, sample_virtual_key["value"], max_tokens=10 + ) + + assert "status_code" in result + assert "success" in result + assert isinstance(result["success"], bool) diff --git a/tests/governance/test_teams_crud.py b/tests/governance/test_teams_crud.py new file mode 100644 index 000000000..169e6b63a --- /dev/null +++ b/tests/governance/test_teams_crud.py @@ -0,0 +1,897 @@ +""" +Comprehensive Team CRUD Tests for Bifrost Governance Plugin + +This module provides exhaustive testing of Team operations including: +- Complete CRUD lifecycle testing +- Comprehensive field update testing (individual and batch) +- Customer association testing +- Budget inheritance and management +- Filtering and query operations +- Edge cases and validation scenarios +- Concurrency and race condition testing +""" + +import pytest +import time +import uuid +from typing import Dict, Any, List +from concurrent.futures import ThreadPoolExecutor +import copy + +from conftest import ( + assert_response_success, + verify_unchanged_fields, + generate_unique_name, + verify_entity_relationships, + deep_compare_entities, +) + + +class TestTeamBasicCRUD: + """Test basic CRUD operations for Teams""" + + @pytest.mark.teams + @pytest.mark.crud + @pytest.mark.smoke + def test_team_create_minimal(self, governance_client, cleanup_tracker): + """Test creating team with minimal required data""" + data = {"name": generate_unique_name("Minimal Team")} + + response = governance_client.create_team(data) + assert_response_success(response, 201) + + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + # Verify required fields + assert team_data["name"] == data["name"] + assert team_data["id"] is not None + assert team_data["created_at"] is not None + assert team_data["updated_at"] is not None + + # Verify optional fields are None/empty + assert team_data["virtual_keys"] is None + + @pytest.mark.teams + @pytest.mark.crud + def test_team_create_with_customer( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test creating team associated with a customer""" + data = { + "name": generate_unique_name("Customer Team"), + "customer_id": sample_customer["id"], + } + + response = governance_client.create_team(data) + assert_response_success(response, 201) + + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + # Verify customer association + assert team_data["customer_id"] == sample_customer["id"] + assert team_data["customer"] is not None + assert team_data["customer"]["id"] == sample_customer["id"] + assert team_data["customer"]["name"] == sample_customer["name"] + + @pytest.mark.teams + @pytest.mark.crud + @pytest.mark.budget + def test_team_create_with_budget(self, governance_client, cleanup_tracker): + """Test creating team with budget""" + data = { + "name": generate_unique_name("Budget Team"), + "budget": {"max_limit": 25000, "reset_duration": "1d"}, # $250.00 in cents + } + + response = governance_client.create_team(data) + assert_response_success(response, 201) + + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + # Verify budget was created + assert team_data["budget"] is not None + assert team_data["budget"]["max_limit"] == 25000 + assert team_data["budget"]["reset_duration"] == "1d" + assert team_data["budget"]["current_usage"] == 0 + assert team_data["budget_id"] is not None + + @pytest.mark.teams + @pytest.mark.crud + @pytest.mark.budget + def test_team_create_complete( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test creating team with all possible fields""" + data = { + "name": generate_unique_name("Complete Team"), + "customer_id": sample_customer["id"], + "budget": { + "max_limit": 100000, # $1000.00 in cents + "reset_duration": "1w", + }, + } + + response = governance_client.create_team(data) + assert_response_success(response, 201) + + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + # Verify all fields + assert team_data["name"] == data["name"] + assert team_data["customer_id"] == sample_customer["id"] + assert team_data["customer"]["id"] == sample_customer["id"] + assert team_data["budget"]["max_limit"] == 100000 + assert team_data["budget"]["reset_duration"] == "1w" + + @pytest.mark.teams + @pytest.mark.crud + def test_team_list_all(self, governance_client, sample_team): + """Test listing all teams""" + response = governance_client.list_teams() + assert_response_success(response, 200) + + data = response.json() + assert "teams" in data + assert "count" in data + assert isinstance(data["teams"], list) + assert data["count"] >= 1 + + # Find our test team + test_team = next( + (team for team in data["teams"] if team["id"] == sample_team["id"]), None + ) + assert test_team is not None + + @pytest.mark.teams + @pytest.mark.crud + def test_team_list_filter_by_customer( + self, governance_client, sample_team_with_customer + ): + """Test listing teams filtered by customer""" + customer_id = sample_team_with_customer["customer_id"] + response = governance_client.list_teams(customer_id=customer_id) + assert_response_success(response, 200) + + data = response.json() + teams = data["teams"] + + # All returned teams should belong to the specified customer + for team in teams: + assert team["customer_id"] == customer_id + + # Our test team should be in the results + test_team = next( + (team for team in teams if team["id"] == sample_team_with_customer["id"]), + None, + ) + assert test_team is not None + + @pytest.mark.teams + @pytest.mark.crud + def test_team_get_by_id(self, governance_client, sample_team): + """Test getting team by ID with relationships loaded""" + response = governance_client.get_team(sample_team["id"]) + assert_response_success(response, 200) + + team_data = response.json()["team"] + assert team_data["id"] == sample_team["id"] + assert team_data["name"] == sample_team["name"] + + @pytest.mark.teams + @pytest.mark.crud + def test_team_get_nonexistent(self, governance_client): + """Test getting non-existent team returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.get_team(fake_id) + assert response.status_code == 404 + + @pytest.mark.teams + @pytest.mark.crud + def test_team_delete(self, governance_client, cleanup_tracker): + """Test deleting a team""" + # Create team to delete + data = {"name": generate_unique_name("Delete Test Team")} + create_response = governance_client.create_team(data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + + # Delete team + delete_response = governance_client.delete_team(team_id) + assert_response_success(delete_response, 200) + + # Verify team is gone + get_response = governance_client.get_team(team_id) + assert get_response.status_code == 404 + + @pytest.mark.teams + @pytest.mark.crud + def test_team_delete_nonexistent(self, governance_client): + """Test deleting non-existent team returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.delete_team(fake_id) + assert response.status_code == 404 + + +class TestTeamValidation: + """Test validation rules for Team operations""" + + @pytest.mark.teams + @pytest.mark.validation + def test_team_create_missing_name(self, governance_client): + """Test creating team without name fails""" + data = {"customer_id": str(uuid.uuid4())} + response = governance_client.create_team(data) + assert response.status_code == 400 + + @pytest.mark.teams + @pytest.mark.validation + def test_team_create_empty_name(self, governance_client): + """Test creating team with empty name fails""" + data = {"name": ""} + response = governance_client.create_team(data) + assert response.status_code == 400 + + @pytest.mark.teams + @pytest.mark.validation + def test_team_create_invalid_customer_id(self, governance_client): + """Test creating team with non-existent customer_id""" + data = { + "name": generate_unique_name("Invalid Customer Team"), + "customer_id": str(uuid.uuid4()), + } + response = governance_client.create_team(data) + # Note: Depending on implementation, this might succeed with warning or fail + # Adjust assertion based on actual API behavior + + @pytest.mark.teams + @pytest.mark.validation + def test_team_create_invalid_budget(self, governance_client): + """Test creating team with invalid budget data""" + # Test negative budget (should be rejected) + data = { + "name": generate_unique_name("Negative Budget Team"), + "budget": {"max_limit": -1000, "reset_duration": "1h"}, + } + response = governance_client.create_team(data) + assert response.status_code == 400 # API should reject negative budgets + + # Test invalid reset duration + data = { + "name": generate_unique_name("Invalid Duration Team"), + "budget": {"max_limit": 1000, "reset_duration": "invalid"}, + } + response = governance_client.create_team(data) + assert response.status_code == 400 + + +class TestTeamFieldUpdates: + """Comprehensive tests for Team field updates""" + + @pytest.mark.teams + @pytest.mark.field_updates + def test_team_update_individual_fields( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test updating each team field individually""" + # Create team with all fields for testing + original_data = { + "name": generate_unique_name("Complete Update Test Team"), + "customer_id": sample_customer["id"], + "budget": {"max_limit": 50000, "reset_duration": "1d"}, + } + create_response = governance_client.create_team(original_data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + # Get original state + original_response = governance_client.get_team(team_id) + original_team = original_response.json()["team"] + + # Create another customer for testing customer_id updates + other_customer_data = {"name": generate_unique_name("Other Customer")} + other_customer_response = governance_client.create_customer(other_customer_data) + assert_response_success(other_customer_response, 201) + other_customer = other_customer_response.json()["customer"] + cleanup_tracker.add_customer(other_customer["id"]) + + # Test individual field updates + field_test_cases = [ + { + "field": "name", + "update_data": {"name": "Updated Team Name"}, + "expected_value": "Updated Team Name", + }, + { + "field": "customer_id", + "update_data": {"customer_id": other_customer["id"]}, + "expected_value": other_customer["id"], + "exclude_from_unchanged_check": ["customer_id", "customer"], + }, + { + "field": "customer_id_clear", + "update_data": {"customer_id": None}, + "expected_value": None, + "exclude_from_unchanged_check": ["customer_id", "customer"], + "custom_validation": lambda team: team["customer_id"] is None + and team["customer"] is None, + }, + ] + + for test_case in field_test_cases: + # Reset team to original state + reset_data = { + "name": original_team["name"], + "customer_id": original_team["customer_id"], + } + governance_client.update_team(team_id, reset_data) + + # Perform field update + response = governance_client.update_team(team_id, test_case["update_data"]) + assert_response_success(response, 200) + updated_team = response.json()["team"] + + # Verify target field was updated + if test_case.get("custom_validation"): + test_case["custom_validation"](updated_team) + else: + field_parts = test_case["field"].split(".") + current_value = updated_team + for part in field_parts: + if part != "clear": # Skip suffix indicators + current_value = current_value[part] + assert ( + current_value == test_case["expected_value"] + ), f"Field {test_case['field']} not updated correctly" + + # Verify other fields unchanged (if specified) + if test_case.get("verify_unchanged", True): + exclude_fields = test_case.get( + "exclude_from_unchanged_check", [test_case["field"]] + ) + verify_unchanged_fields(updated_team, original_team, exclude_fields) + + @pytest.mark.teams + @pytest.mark.field_updates + @pytest.mark.budget + def test_team_budget_updates(self, governance_client, cleanup_tracker): + """Test comprehensive budget creation, update, and modification""" + # Create team without budget + data = {"name": generate_unique_name("Budget Update Test Team")} + create_response = governance_client.create_team(data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + # Test 1: Add budget to team without budget + budget_data = {"max_limit": 15000, "reset_duration": "1h"} + response = governance_client.update_team(team_id, {"budget": budget_data}) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["budget"]["max_limit"] == 15000 + assert updated_team["budget"]["reset_duration"] == "1h" + assert updated_team["budget_id"] is not None + + # Test 2: Update existing budget completely + new_budget_data = {"max_limit": 30000, "reset_duration": "2h"} + response = governance_client.update_team(team_id, {"budget": new_budget_data}) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["budget"]["max_limit"] == 30000 + assert updated_team["budget"]["reset_duration"] == "2h" + + # Test 3: Partial budget update (only max_limit) + response = governance_client.update_team( + team_id, {"budget": {"max_limit": 45000}} + ) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["budget"]["max_limit"] == 45000 + assert ( + updated_team["budget"]["reset_duration"] == "2h" + ) # Should remain unchanged + + # Test 4: Partial budget update (only reset_duration) + response = governance_client.update_team( + team_id, {"budget": {"reset_duration": "1d"}} + ) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["budget"]["max_limit"] == 45000 # Should remain unchanged + assert updated_team["budget"]["reset_duration"] == "1d" + + @pytest.mark.teams + @pytest.mark.field_updates + def test_team_multiple_field_updates( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test updating multiple fields simultaneously""" + # Create team with initial data + initial_data = { + "name": generate_unique_name("Multi-Field Test Team"), + } + create_response = governance_client.create_team(initial_data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + # Update multiple fields at once + update_data = { + "name": "Updated Multi-Field Team Name", + "customer_id": sample_customer["id"], + "budget": {"max_limit": 75000, "reset_duration": "1w"}, + } + + response = governance_client.update_team(team_id, update_data) + assert_response_success(response, 200) + + updated_team = response.json()["team"] + assert updated_team["name"] == "Updated Multi-Field Team Name" + assert updated_team["customer_id"] == sample_customer["id"] + assert updated_team["customer"]["id"] == sample_customer["id"] + assert updated_team["budget"]["max_limit"] == 75000 + assert updated_team["budget"]["reset_duration"] == "1w" + + @pytest.mark.teams + @pytest.mark.field_updates + @pytest.mark.edge_cases + def test_team_update_edge_cases(self, governance_client, cleanup_tracker): + """Test edge cases in team updates""" + # Create test team + data = {"name": generate_unique_name("Edge Case Team")} + create_response = governance_client.create_team(data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + original_response = governance_client.get_team(team_id) + original_team = original_response.json()["team"] + + # Test 1: Empty update (should return unchanged team) + response = governance_client.update_team(team_id, {}) + assert_response_success(response, 200) + updated_team = response.json()["team"] + + # Compare ignoring timestamps + differences = deep_compare_entities( + updated_team, original_team, ignore_fields=["updated_at"] + ) + assert len(differences) == 0, f"Empty update changed fields: {differences}" + + # Test 2: Update with same values + response = governance_client.update_team( + team_id, {"name": original_team["name"]} + ) + assert_response_success(response, 200) + + # Test 3: Very long team name (test field length limits) + long_name = "x" * 1000 # Adjust based on actual field limits + response = governance_client.update_team(team_id, {"name": long_name}) + # Expected behavior depends on API validation rules + + @pytest.mark.teams + @pytest.mark.field_updates + def test_team_update_nonexistent(self, governance_client): + """Test updating non-existent team returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.update_team(fake_id, {"name": "test"}) + assert response.status_code == 404 + + +class TestTeamBudgetManagement: + """Test team budget specific functionality""" + + @pytest.mark.teams + @pytest.mark.budget + def test_team_budget_creation_and_validation( + self, governance_client, cleanup_tracker + ): + """Test budget creation with various configurations""" + # Test valid budget configurations + budget_test_cases = [ + {"max_limit": 5000, "reset_duration": "1h"}, + {"max_limit": 25000, "reset_duration": "1d"}, + {"max_limit": 100000, "reset_duration": "1w"}, + {"max_limit": 500000, "reset_duration": "1M"}, + ] + + for budget_config in budget_test_cases: + data = { + "name": generate_unique_name( + f"Budget Team {budget_config['reset_duration']}" + ), + "budget": budget_config, + } + + response = governance_client.create_team(data) + assert_response_success(response, 201) + + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + assert team_data["budget"]["max_limit"] == budget_config["max_limit"] + assert ( + team_data["budget"]["reset_duration"] == budget_config["reset_duration"] + ) + assert team_data["budget"]["current_usage"] == 0 + assert team_data["budget"]["last_reset"] is not None + + @pytest.mark.teams + @pytest.mark.budget + @pytest.mark.edge_cases + def test_team_budget_edge_cases(self, governance_client, cleanup_tracker): + """Test budget edge cases and boundary conditions""" + # Test boundary values + edge_case_budgets = [ + {"max_limit": 0, "reset_duration": "1h"}, # Zero budget + {"max_limit": 1, "reset_duration": "1s"}, # Minimal values + {"max_limit": 9223372036854775807, "reset_duration": "1h"}, # Max int64 + ] + + for budget_config in edge_case_budgets: + data = { + "name": generate_unique_name( + f"Edge Budget Team {budget_config['max_limit']}" + ), + "budget": budget_config, + } + + response = governance_client.create_team(data) + # Adjust assertions based on API validation rules + if ( + budget_config["max_limit"] >= 0 + ): # Assuming non-negative budgets are valid + assert_response_success(response, 201) + cleanup_tracker.add_team(response.json()["team"]["id"]) + else: + assert response.status_code == 400 + + @pytest.mark.teams + @pytest.mark.budget + def test_team_budget_inheritance_simulation( + self, governance_client, cleanup_tracker + ): + """Test team budget in context of hierarchical inheritance""" + # This test simulates budget inheritance behavior + # Actual inheritance testing would be in integration tests + + # Create customer with budget + customer_data = { + "name": generate_unique_name("Budget Customer"), + "budget": {"max_limit": 100000, "reset_duration": "1d"}, + } + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create team with smaller budget under customer + team_data = { + "name": generate_unique_name("Sub-Budget Team"), + "customer_id": customer["id"], + "budget": { + "max_limit": 25000, + "reset_duration": "1d", + }, # Smaller than customer + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team = team_response.json()["team"] + cleanup_tracker.add_team(team["id"]) + + # Verify both budgets exist independently + assert team["budget"]["max_limit"] == 25000 + # Note: Customer budget not preloaded in team response (use customer endpoint to verify) + customer_response = governance_client.get_customer(customer["id"]) + customer_with_budget = customer_response.json()["customer"] + assert customer_with_budget["budget"]["max_limit"] == 100000 + + # Create team without budget under customer (should inherit) + no_budget_team_data = { + "name": generate_unique_name("Inherit Budget Team"), + "customer_id": customer["id"], + } + no_budget_response = governance_client.create_team(no_budget_team_data) + assert_response_success(no_budget_response, 201) + no_budget_team = no_budget_response.json()["team"] + cleanup_tracker.add_team(no_budget_team["id"]) + + # Team without explicit budget should not have budget field (omitempty) + assert no_budget_team.get("budget") is None + # Verify customer has budget (need to fetch customer directly due to preloading limits) + customer_check = governance_client.get_customer(customer["id"]) + assert customer_check.json()["customer"]["budget"]["max_limit"] == 100000 + + +class TestTeamRelationships: + """Test team relationships with customers""" + + @pytest.mark.teams + @pytest.mark.relationships + def test_team_customer_relationship_loading( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test that team properly loads customer relationships""" + data = { + "name": generate_unique_name("Customer Relationship Team"), + "customer_id": sample_customer["id"], + } + + response = governance_client.create_team(data) + assert_response_success(response, 201) + team_data = response.json()["team"] + cleanup_tracker.add_team(team_data["id"]) + + # Verify customer relationship loaded + assert team_data["customer"] is not None + assert team_data["customer"]["id"] == sample_customer["id"] + assert team_data["customer"]["name"] == sample_customer["name"] + + # Verify customer budget relationship loaded if it exists + if sample_customer.get("budget"): + assert team_data["customer"]["budget"] is not None + + @pytest.mark.teams + @pytest.mark.relationships + def test_team_orphaned_customer_reference(self, governance_client, cleanup_tracker): + """Test team behavior with orphaned customer reference""" + # Create team with non-existent customer_id + fake_customer_id = str(uuid.uuid4()) + data = { + "name": generate_unique_name("Orphaned Team"), + "customer_id": fake_customer_id, + } + + response = governance_client.create_team(data) + # Behavior depends on API implementation: + # - Might succeed with warning + # - Might fail with validation error + # Adjust assertion based on actual behavior + + if response.status_code == 201: + cleanup_tracker.add_team(response.json()["team"]["id"]) + # Verify team was created but customer relationship is null/missing + team_data = response.json()["team"] + assert team_data.get("customer") is None + else: + assert response.status_code == 400 # Validation error expected + + @pytest.mark.teams + @pytest.mark.relationships + def test_team_customer_association_changes( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test changing team customer associations""" + # Create standalone team + data = {"name": generate_unique_name("Association Test Team")} + create_response = governance_client.create_team(data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + # Create another customer + other_customer_data = {"name": generate_unique_name("Other Customer")} + other_customer_response = governance_client.create_customer(other_customer_data) + assert_response_success(other_customer_response, 201) + other_customer = other_customer_response.json()["customer"] + cleanup_tracker.add_customer(other_customer["id"]) + + # Test 1: Associate with first customer + response = governance_client.update_team( + team_id, {"customer_id": sample_customer["id"]} + ) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["customer_id"] == sample_customer["id"] + assert updated_team["customer"]["id"] == sample_customer["id"] + + # Test 2: Switch to other customer + response = governance_client.update_team( + team_id, {"customer_id": other_customer["id"]} + ) + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["customer_id"] == other_customer["id"] + assert updated_team["customer"]["id"] == other_customer["id"] + + # Test 3: Remove customer association + response = governance_client.update_team(team_id, {"customer_id": None}) + # Note: Behavior depends on API implementation + # Adjust assertion based on actual behavior + + +class TestTeamConcurrency: + """Test concurrent operations on Teams""" + + @pytest.mark.teams + @pytest.mark.concurrency + @pytest.mark.slow + def test_team_concurrent_creation(self, governance_client, cleanup_tracker): + """Test creating multiple teams concurrently""" + + def create_team(index): + data = {"name": generate_unique_name(f"Concurrent Team {index}")} + response = governance_client.create_team(data) + return response + + # Create 10 teams concurrently + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(create_team, i) for i in range(10)] + responses = [future.result() for future in futures] + + # Verify all succeeded + created_teams = [] + for response in responses: + assert_response_success(response, 201) + team_data = response.json()["team"] + created_teams.append(team_data) + cleanup_tracker.add_team(team_data["id"]) + + # Verify all teams have unique IDs + team_ids = [team["id"] for team in created_teams] + assert len(set(team_ids)) == 10 # All unique IDs + + @pytest.mark.teams + @pytest.mark.concurrency + @pytest.mark.slow + def test_team_concurrent_updates(self, governance_client, cleanup_tracker): + """Test updating same team concurrently""" + # Create team to update + data = {"name": generate_unique_name("Concurrent Update Team")} + create_response = governance_client.create_team(data) + assert_response_success(create_response, 201) + team_id = create_response.json()["team"]["id"] + cleanup_tracker.add_team(team_id) + + # Update concurrently with different names + def update_team(index): + update_data = {"name": f"Updated by thread {index}"} + response = governance_client.update_team(team_id, update_data) + return response, index + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_team, i) for i in range(5)] + results = [future.result() for future in futures] + + # All updates should succeed (last one wins) + for response, index in results: + assert_response_success(response, 200) + + # Verify final state + final_response = governance_client.get_team(team_id) + final_team = final_response.json()["team"] + assert final_team["name"].startswith("Updated by thread") + + @pytest.mark.teams + @pytest.mark.concurrency + @pytest.mark.slow + def test_team_concurrent_customer_association( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test concurrent customer association updates""" + # Create multiple teams to associate with same customer + teams = [] + for i in range(5): + data = {"name": generate_unique_name(f"Concurrent Association Team {i}")} + response = governance_client.create_team(data) + assert_response_success(response, 201) + team_data = response.json()["team"] + teams.append(team_data) + cleanup_tracker.add_team(team_data["id"]) + + # Associate all teams with customer concurrently + def associate_team(team): + update_data = {"customer_id": sample_customer["id"]} + response = governance_client.update_team(team["id"], update_data) + return response, team["id"] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(associate_team, team) for team in teams] + results = [future.result() for future in futures] + + # All associations should succeed + for response, team_id in results: + assert_response_success(response, 200) + updated_team = response.json()["team"] + assert updated_team["customer_id"] == sample_customer["id"] + + +class TestTeamFiltering: + """Test team filtering and query operations""" + + @pytest.mark.teams + @pytest.mark.api + def test_team_filter_by_customer_comprehensive( + self, governance_client, cleanup_tracker + ): + """Test comprehensive customer filtering scenarios""" + # Create customers + customer1_data = {"name": generate_unique_name("Filter Customer 1")} + customer1_response = governance_client.create_customer(customer1_data) + assert_response_success(customer1_response, 201) + customer1 = customer1_response.json()["customer"] + cleanup_tracker.add_customer(customer1["id"]) + + customer2_data = {"name": generate_unique_name("Filter Customer 2")} + customer2_response = governance_client.create_customer(customer2_data) + assert_response_success(customer2_response, 201) + customer2 = customer2_response.json()["customer"] + cleanup_tracker.add_customer(customer2["id"]) + + # Create teams for customer1 + for i in range(3): + team_data = { + "name": generate_unique_name(f"Customer1 Team {i}"), + "customer_id": customer1["id"], + } + response = governance_client.create_team(team_data) + assert_response_success(response, 201) + cleanup_tracker.add_team(response.json()["team"]["id"]) + + # Create teams for customer2 + for i in range(2): + team_data = { + "name": generate_unique_name(f"Customer2 Team {i}"), + "customer_id": customer2["id"], + } + response = governance_client.create_team(team_data) + assert_response_success(response, 201) + cleanup_tracker.add_team(response.json()["team"]["id"]) + + # Create standalone team + standalone_data = {"name": generate_unique_name("Standalone Team")} + response = governance_client.create_team(standalone_data) + assert_response_success(response, 201) + cleanup_tracker.add_team(response.json()["team"]["id"]) + + # Test filtering by customer1 + response = governance_client.list_teams(customer_id=customer1["id"]) + assert_response_success(response, 200) + teams = response.json()["teams"] + assert len(teams) == 3 + for team in teams: + assert team["customer_id"] == customer1["id"] + + # Test filtering by customer2 + response = governance_client.list_teams(customer_id=customer2["id"]) + assert_response_success(response, 200) + teams = response.json()["teams"] + assert len(teams) == 2 + for team in teams: + assert team["customer_id"] == customer2["id"] + + # Test filtering by non-existent customer + fake_customer_id = str(uuid.uuid4()) + response = governance_client.list_teams(customer_id=fake_customer_id) + assert_response_success(response, 200) + teams = response.json()["teams"] + assert len(teams) == 0 + + @pytest.mark.teams + @pytest.mark.api + def test_team_list_pagination_and_sorting(self, governance_client, cleanup_tracker): + """Test team list with pagination and sorting (if supported by API)""" + # Create multiple teams for testing + team_names = [] + for i in range(10): + name = generate_unique_name(f"Sort Test Team {i:02d}") + team_names.append(name) + data = {"name": name} + response = governance_client.create_team(data) + assert_response_success(response, 201) + cleanup_tracker.add_team(response.json()["team"]["id"]) + + # Test basic list (should include our teams) + response = governance_client.list_teams() + assert_response_success(response, 200) + teams = response.json()["teams"] + assert len(teams) >= 10 + + # Verify our teams are in the response + response_team_names = {team["name"] for team in teams} + for name in team_names: + assert name in response_team_names diff --git a/tests/governance/test_usage_tracking.py b/tests/governance/test_usage_tracking.py new file mode 100644 index 000000000..aaa5724cc --- /dev/null +++ b/tests/governance/test_usage_tracking.py @@ -0,0 +1,1061 @@ +""" +Comprehensive Usage Tracking and Monitoring Tests for Bifrost Governance Plugin + +This module provides exhaustive testing of usage tracking, monitoring, and integration including: +- Chat completion integration with governance headers +- Usage tracking and budget enforcement +- Rate limiting enforcement during real requests +- Monitoring endpoints testing +- Reset functionality testing +- Debug and health endpoints +- Integration edge cases and error scenarios +- Performance and concurrency testing +""" + +import pytest +import time +import uuid +import json +from typing import Dict, Any, List +from concurrent.futures import ThreadPoolExecutor +import threading + +from conftest import ( + assert_response_success, + generate_unique_name, + wait_for_condition, + BIFROST_BASE_URL, +) + + +class TestUsageStatsEndpoints: + """Test usage statistics and monitoring endpoints""" + + @pytest.mark.usage_tracking + @pytest.mark.api + @pytest.mark.smoke + def test_get_usage_stats_general(self, governance_client): + """Test getting general usage statistics""" + response = governance_client.get_usage_stats() + assert_response_success(response, 200) + + stats = response.json() + # Stats structure depends on implementation, but should be valid JSON + assert isinstance(stats, dict) + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_get_usage_stats_for_vk(self, governance_client, sample_virtual_key): + """Test getting usage statistics for specific VK""" + response = governance_client.get_usage_stats( + virtual_key_id=sample_virtual_key["id"] + ) + assert_response_success(response, 200) + + data = response.json() + assert "virtual_key_id" in data + assert data["virtual_key_id"] == sample_virtual_key["id"] + assert "usage_stats" in data + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_get_usage_stats_nonexistent_vk(self, governance_client): + """Test getting usage stats for non-existent VK""" + fake_vk_id = str(uuid.uuid4()) + response = governance_client.get_usage_stats(virtual_key_id=fake_vk_id) + # Behavior depends on implementation - might return empty stats or 404 + assert response.status_code in [200, 404] + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_reset_usage_basic(self, governance_client, sample_virtual_key): + """Test basic usage reset functionality""" + reset_data = {"virtual_key_id": sample_virtual_key["id"]} + + response = governance_client.reset_usage(reset_data) + assert_response_success(response, 200) + + result = response.json() + assert "message" in result + assert "successfully" in result["message"].lower() + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_reset_usage_with_provider_and_model( + self, governance_client, sample_virtual_key + ): + """Test usage reset with specific provider and model""" + reset_data = { + "virtual_key_id": sample_virtual_key["id"], + "provider": "openai", + "model": "gpt-4", + } + + response = governance_client.reset_usage(reset_data) + assert_response_success(response, 200) + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_reset_usage_invalid_vk(self, governance_client): + """Test usage reset with invalid VK ID""" + reset_data = {"virtual_key_id": str(uuid.uuid4())} + + response = governance_client.reset_usage(reset_data) + assert response.status_code in [400, 404, 500] # Expected error + + +class TestDebugEndpoints: + """Test debug and monitoring endpoints""" + + @pytest.mark.usage_tracking + @pytest.mark.api + @pytest.mark.smoke + def test_get_debug_stats(self, governance_client): + """Test debug statistics endpoint""" + response = governance_client.get_debug_stats() + assert_response_success(response, 200) + + data = response.json() + assert "plugin_stats" in data + assert "database_stats" in data + assert "timestamp" in data + + @pytest.mark.usage_tracking + @pytest.mark.api + def test_get_debug_counters(self, governance_client): + """Test debug counters endpoint""" + response = governance_client.get_debug_counters() + assert_response_success(response, 200) + + data = response.json() + assert "counters" in data + assert "count" in data + assert "timestamp" in data + assert isinstance(data["counters"], list) + + @pytest.mark.usage_tracking + @pytest.mark.api + @pytest.mark.smoke + def test_get_health_check(self, governance_client): + """Test health check endpoint""" + response = governance_client.get_health_check() + # Health check should return 200 for healthy or 503 for unhealthy + assert response.status_code in [200, 503] + + data = response.json() + assert "status" in data + assert "timestamp" in data + assert "checks" in data + assert data["status"] in ["healthy", "unhealthy"] + + +class TestChatCompletionIntegration: + """Test chat completion integration with governance headers""" + + @pytest.mark.integration + @pytest.mark.usage_tracking + @pytest.mark.smoke + def test_chat_completion_with_vk_header( + self, governance_client, sample_virtual_key + ): + """Test chat completion with valid VK header""" + messages = [{"role": "user", "content": "Hello, world!"}] + headers = {"x-bf-vk": sample_virtual_key["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=10, + ) + + # Response should be successful, rate limited, budget exceeded, or VK not found + assert response.status_code in [200, 429, 402, 403] + + if response.status_code == 200: + data = response.json() + assert "choices" in data + assert len(data["choices"]) > 0 + + @pytest.mark.integration + @pytest.mark.usage_tracking + def test_chat_completion_without_vk_header(self, governance_client): + """Test chat completion without VK header""" + messages = [{"role": "user", "content": "Hello, world!"}] + + response = governance_client.chat_completion( + messages=messages, model="openai/gpt-3.5-turbo", max_tokens=10 + ) + + # Should succeed without VK header (governance skipped) + assert response.status_code in [ + 200, + 400, + ] # 200 if no governance, 400 if provider issues + + @pytest.mark.integration + @pytest.mark.usage_tracking + def test_chat_completion_invalid_vk_header(self, governance_client): + """Test chat completion with invalid VK header""" + messages = [{"role": "user", "content": "Hello, world!"}] + headers = {"x-bf-vk": "invalid-vk-value"} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=10, + ) + + # Should fail with invalid VK (governance blocks) + assert response.status_code == 403 + + @pytest.mark.integration + @pytest.mark.usage_tracking + def test_chat_completion_inactive_vk(self, governance_client, cleanup_tracker): + """Test chat completion with inactive VK""" + # Create inactive VK + vk_data = {"name": generate_unique_name("Inactive VK"), "is_active": False} + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + inactive_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(inactive_vk["id"]) + + messages = [{"role": "user", "content": "Hello, world!"}] + headers = {"x-bf-vk": inactive_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=10, + ) + + # Should fail with inactive VK (governance blocks) + assert response.status_code == 403 + + @pytest.mark.integration + @pytest.mark.usage_tracking + def test_chat_completion_with_model_restrictions( + self, governance_client, cleanup_tracker + ): + """Test chat completion with model restrictions""" + # Create VK with model restrictions + vk_data = { + "name": generate_unique_name("Restricted VK"), + "allowed_models": ["gpt-4"], # Only allow GPT-4 + "allowed_providers": ["openai"], + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + restricted_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(restricted_vk["id"]) + + # Test with allowed model + messages = [{"role": "user", "content": "Hello, world!"}] + headers = {"x-bf-vk": restricted_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, model="gpt-4", headers=headers, max_tokens=10 + ) + + # Should work with allowed model + assert response.status_code in [200, 429, 402] # Success or limits + + # Test with disallowed model + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", # Not in allowed_models + headers=headers, + max_tokens=10, + ) + + # Should fail with disallowed model + assert response.status_code in [400, 403] + + +class TestBudgetEnforcement: + """Test budget enforcement during chat completions""" + + @pytest.mark.integration + @pytest.mark.budget + @pytest.mark.usage_tracking + def test_budget_enforcement_basic(self, governance_client, cleanup_tracker): + """Test basic budget enforcement""" + # Create VK with very small budget + vk_data = { + "name": generate_unique_name("Small Budget VK"), + "budget": { + "max_limit": 1, # 1 cent - very small budget + "reset_duration": "1h", + }, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + small_budget_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(small_budget_vk["id"]) + + messages = [ + { + "role": "user", + "content": "Write a very long story about artificial intelligence" * 10, + } + ] + headers = {"x-bf-vk": small_budget_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=1000, # Request expensive completion + ) + + # Should fail due to budget exceeded + if response.status_code == 402: # Budget exceeded + error_data = response.json() + assert "budget" in error_data.get("error", "").lower() + elif response.status_code == 200: + # If it succeeded, check that budget was tracked + stats_response = governance_client.get_usage_stats( + virtual_key_id=small_budget_vk["id"] + ) + if stats_response.status_code == 200: + # Verify usage was tracked + pass + + @pytest.mark.integration + @pytest.mark.budget + @pytest.mark.usage_tracking + def test_hierarchical_budget_enforcement(self, governance_client, cleanup_tracker): + """Test hierarchical budget enforcement (Customer -> Team -> VK)""" + # Create customer with budget + customer_data = { + "name": generate_unique_name("Budget Test Customer"), + "budget": {"max_limit": 10000, "reset_duration": "1h"}, + } + customer_response = governance_client.create_customer(customer_data) + assert_response_success(customer_response, 201) + customer = customer_response.json()["customer"] + cleanup_tracker.add_customer(customer["id"]) + + # Create team under customer with smaller budget + team_data = { + "name": generate_unique_name("Budget Test Team"), + "customer_id": customer["id"], + "budget": {"max_limit": 5000, "reset_duration": "1h"}, + } + team_response = governance_client.create_team(team_data) + assert_response_success(team_response, 201) + team = team_response.json()["team"] + cleanup_tracker.add_team(team["id"]) + + # Create VK under team with even smaller budget + vk_data = { + "name": generate_unique_name("Budget Test VK"), + "team_id": team["id"], + "budget": {"max_limit": 1, "reset_duration": "1h"}, # Smallest budget + } + vk_response = governance_client.create_virtual_key(vk_data) + assert_response_success(vk_response, 201) + vk = vk_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + # Test request that should hit VK budget first + messages = [{"role": "user", "content": "Expensive request" * 50}] + headers = {"x-bf-vk": vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="gpt-4", # More expensive model + headers=headers, + max_tokens=1000, + ) + + # Should be limited by VK budget (smallest in hierarchy) + # Actual behavior depends on implementation + + @pytest.mark.integration + @pytest.mark.budget + @pytest.mark.usage_tracking + def test_budget_reset_functionality(self, governance_client, cleanup_tracker): + """Test budget reset functionality""" + # Create VK with small budget + vk_data = { + "name": generate_unique_name("Reset Budget VK"), + "budget": {"max_limit": 100, "reset_duration": "1h"}, # Small but not tiny + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + # Make a request to use some budget + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Reset the usage + reset_data = {"virtual_key_id": vk["id"]} + reset_response = governance_client.reset_usage(reset_data) + assert_response_success(reset_response, 200) + + # Budget should be reset - could make another request + response2 = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Should work after reset (unless other limits apply) + assert response2.status_code in [200, 429] # Success or rate limited + + +class TestRateLimitEnforcement: + """Test rate limiting enforcement during chat completions""" + + @pytest.mark.integration + @pytest.mark.rate_limit + @pytest.mark.usage_tracking + def test_request_rate_limiting(self, governance_client, cleanup_tracker): + """Test request rate limiting""" + # Create VK with very restrictive request rate limit + vk_data = { + "name": generate_unique_name("Rate Limited VK"), + "rate_limit": { + "request_max_limit": 2, # Only 2 requests allowed + "request_reset_duration": "1m", + }, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + rate_limited_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(rate_limited_vk["id"]) + + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": rate_limited_vk["value"]} + + # Make requests up to the limit + responses = [] + for i in range(3): # Try 3 requests, limit is 2 + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + responses.append(response) + time.sleep(0.1) # Small delay + + # First 2 should succeed, 3rd should be rate limited + success_count = sum(1 for r in responses if r.status_code == 200) + rate_limited_count = sum(1 for r in responses if r.status_code == 429) + + # Depending on implementation, might be exactly enforced or allow some variance + assert rate_limited_count > 0 or success_count <= 2 + + @pytest.mark.integration + @pytest.mark.rate_limit + @pytest.mark.usage_tracking + def test_token_rate_limiting(self, governance_client, cleanup_tracker): + """Test token rate limiting""" + # Create VK with restrictive token rate limit + vk_data = { + "name": generate_unique_name("Token Rate Limited VK"), + "rate_limit": { + "token_max_limit": 100, # Only 100 tokens allowed + "token_reset_duration": "1m", + }, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + token_limited_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(token_limited_vk["id"]) + + # Make request that would exceed token limit + messages = [ + {"role": "user", "content": "Write a very long response about AI" * 10} + ] + headers = {"x-bf-vk": token_limited_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=500, # Request more tokens than limit + ) + + # Should be limited by token rate limit + if response.status_code == 429: + error_data = response.json() + # Check if error mentions tokens or rate limit + error_text = error_data.get("error", "").lower() + assert "token" in error_text or "rate" in error_text + + @pytest.mark.integration + @pytest.mark.rate_limit + @pytest.mark.usage_tracking + def test_independent_rate_limits(self, governance_client, cleanup_tracker): + """Test that token and request rate limits are independent""" + # Create VK with different token and request limits + vk_data = { + "name": generate_unique_name("Independent Limits VK"), + "rate_limit": { + "token_max_limit": 1000, + "token_reset_duration": "1h", + "request_max_limit": 5, + "request_reset_duration": "1m", + }, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + independent_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(independent_vk["id"]) + + messages = [{"role": "user", "content": "Short"}] + headers = {"x-bf-vk": independent_vk["value"]} + + # Make multiple small requests (should hit request limit before token limit) + responses = [] + for i in range(10): # More than request limit + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, # Small token count + ) + responses.append(response) + time.sleep(0.1) + + # Should be limited by request count, not tokens + rate_limited_responses = [r for r in responses if r.status_code == 429] + assert len(rate_limited_responses) > 0 + + @pytest.mark.integration + @pytest.mark.rate_limit + @pytest.mark.usage_tracking + def test_rate_limit_reset(self, governance_client, cleanup_tracker): + """Test rate limit reset functionality""" + # Create VK with short reset duration for testing + vk_data = { + "name": generate_unique_name("Reset Test VK"), + "rate_limit": { + "request_max_limit": 1, + "request_reset_duration": "5s", # Short duration for testing + }, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + reset_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(reset_vk["id"]) + + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": reset_vk["value"]} + + # Make first request (should succeed) + response1 = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Make second request immediately (should be rate limited) + response2 = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Reset the rate limit manually + reset_data = {"virtual_key_id": reset_vk["id"]} + reset_response = governance_client.reset_usage(reset_data) + assert_response_success(reset_response, 200) + + # Make third request after reset (should succeed) + response3 = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Should work after reset + assert response3.status_code in [200, 429] # Success or different limit + + +class TestConcurrentUsageTracking: + """Test concurrent usage tracking and limits""" + + @pytest.mark.integration + @pytest.mark.concurrency + @pytest.mark.usage_tracking + @pytest.mark.slow + def test_concurrent_requests_same_vk(self, governance_client, cleanup_tracker): + """Test concurrent requests using same VK""" + # Create VK with moderate limits + vk_data = { + "name": generate_unique_name("Concurrent VK"), + "rate_limit": {"request_max_limit": 10, "request_reset_duration": "1m"}, + "budget": {"max_limit": 10000, "reset_duration": "1h"}, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + concurrent_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(concurrent_vk["id"]) + + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": concurrent_vk["value"]} + + def make_request(index): + try: + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + return response.status_code, index + except Exception as e: + return str(e), index + + # Make 15 concurrent requests (more than rate limit) + with ThreadPoolExecutor(max_workers=15) as executor: + futures = [executor.submit(make_request, i) for i in range(15)] + results = [future.result() for future in futures] + + # Count success vs rate limited responses + success_codes = [r[0] for r in results if r[0] == 200] + rate_limited_codes = [r[0] for r in results if r[0] == 429] + + # Should have some successful and some rate limited + total_responses = len(success_codes) + len(rate_limited_codes) + assert total_responses > 0 + + # Rate limiting should have kicked in for some requests + assert len(success_codes) <= 10 # Shouldn't exceed rate limit + + @pytest.mark.integration + @pytest.mark.concurrency + @pytest.mark.usage_tracking + @pytest.mark.slow + def test_concurrent_budget_tracking(self, governance_client, cleanup_tracker): + """Test concurrent budget tracking accuracy""" + # Create VK with small budget for testing + vk_data = { + "name": generate_unique_name("Budget Tracking VK"), + "budget": {"max_limit": 1000, "reset_duration": "1h"}, # Small budget + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + budget_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(budget_vk["id"]) + + messages = [{"role": "user", "content": "Count to 10"}] + headers = {"x-bf-vk": budget_vk["value"]} + + def make_budget_request(index): + try: + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=20, + ) + return ( + response.status_code, + index, + response.json() if response.status_code == 200 else None, + ) + except Exception as e: + return str(e), index, None + + # Make concurrent requests that should consume budget + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(make_budget_request, i) for i in range(5)] + results = [future.result() for future in futures] + + # Check budget tracking consistency + success_count = sum(1 for r in results if r[0] == 200) + budget_exceeded_count = sum(1 for r in results if r[0] == 402) + + # Should have proper budget enforcement + assert success_count + budget_exceeded_count > 0 + + +class TestStreamingIntegration: + """Test streaming integration with governance""" + + @pytest.mark.integration + @pytest.mark.usage_tracking + def test_streaming_chat_completion_with_governance( + self, governance_client, sample_virtual_key + ): + """Test streaming chat completion with governance headers""" + messages = [{"role": "user", "content": "Count from 1 to 5"}] + headers = {"x-bf-vk": sample_virtual_key["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=50, + stream=True, + ) + + # Streaming should work with governance + if response.status_code == 200: + # For streaming, response should be text/event-stream + content_type = response.headers.get("content-type", "") + assert ( + "text/event-stream" in content_type + or "application/json" in content_type + ) + else: + # Should be properly governed (rate limited, budget exceeded, etc.) + assert response.status_code in [402, 403, 429] + + @pytest.mark.integration + @pytest.mark.usage_tracking + @pytest.mark.rate_limit + def test_streaming_rate_limit_enforcement(self, governance_client, cleanup_tracker): + """Test rate limiting during streaming requests""" + # Create VK with token rate limit + vk_data = { + "name": generate_unique_name("Streaming Rate Limit VK"), + "rate_limit": {"token_max_limit": 50, "token_reset_duration": "1m"}, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + streaming_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(streaming_vk["id"]) + + messages = [{"role": "user", "content": "Write a long story about AI"}] + headers = {"x-bf-vk": streaming_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=200, # More than token limit + stream=True, + ) + + # Should be limited by token rate limit + if response.status_code == 429: + error_data = response.json() + assert "token" in error_data.get("error", "").lower() + + +class TestProviderModelValidation: + """Test provider and model validation during integration""" + + @pytest.mark.integration + @pytest.mark.validation + def test_anthropic_model_integration(self, governance_client, cleanup_tracker): + """Test integration with Anthropic models""" + # Create VK allowing Anthropic + vk_data = { + "name": generate_unique_name("Anthropic VK"), + "allowed_providers": ["anthropic"], + "allowed_models": ["claude-3-5-sonnet-20240620"], + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + anthropic_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(anthropic_vk["id"]) + + messages = [{"role": "user", "content": "Hello Claude"}] + headers = {"x-bf-vk": anthropic_vk["value"]} + + response = governance_client.chat_completion( + messages=messages, + model="claude-3-5-sonnet-20240620", + headers=headers, + max_tokens=10, + ) + + # Should work if Anthropic is properly configured + assert response.status_code in [200, 400, 402, 429, 503] + + @pytest.mark.integration + @pytest.mark.validation + def test_openai_model_integration(self, governance_client, cleanup_tracker): + """Test integration with OpenAI models""" + # Create VK allowing OpenAI + vk_data = { + "name": generate_unique_name("OpenAI VK"), + "allowed_providers": ["openai"], + "allowed_models": ["gpt-4", "gpt-3.5-turbo"], + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + openai_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(openai_vk["id"]) + + messages = [{"role": "user", "content": "Hello GPT"}] + headers = {"x-bf-vk": openai_vk["value"]} + + # Test GPT-4 + response = governance_client.chat_completion( + messages=messages, model="gpt-4", headers=headers, max_tokens=10 + ) + + # Should work if OpenAI is properly configured + assert response.status_code in [200, 400, 402, 429, 503] + + @pytest.mark.integration + @pytest.mark.validation + def test_disallowed_provider_model_combination( + self, governance_client, cleanup_tracker + ): + """Test disallowed provider/model combinations""" + # Create VK only allowing OpenAI + vk_data = { + "name": generate_unique_name("OpenAI Only VK"), + "allowed_providers": ["openai"], + "allowed_models": ["gpt-4"], + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + restricted_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(restricted_vk["id"]) + + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": restricted_vk["value"]} + + # Try to use Anthropic model (should fail) + response = governance_client.chat_completion( + messages=messages, + model="claude-3-5-sonnet-20240620", + headers=headers, + max_tokens=10, + ) + + # Should be rejected for disallowed model + assert response.status_code in [400, 403] + + +class TestErrorHandlingAndEdgeCases: + """Test error handling and edge cases in usage tracking""" + + @pytest.mark.integration + @pytest.mark.edge_cases + def test_malformed_vk_header(self, governance_client): + """Test malformed VK header handling""" + messages = [{"role": "user", "content": "Hello"}] + + malformed_headers = [ + {"x-bf-vk": ""}, # Empty + {"x-bf-vk": " "}, # Whitespace + {"x-bf-vk": "short"}, # Too short + {"x-bf-vk": "x" * 100}, # Too long + {"x-bf-vk": "invalid-characters-#@!"}, # Invalid chars + ] + + for headers in malformed_headers: + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Should properly reject malformed headers + assert response.status_code in [400, 403] + + @pytest.mark.integration + @pytest.mark.edge_cases + def test_concurrent_vk_updates_during_requests( + self, governance_client, cleanup_tracker + ): + """Test VK updates during active requests""" + # Create VK + vk_data = {"name": generate_unique_name("Update Test VK")} + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + update_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(update_vk["id"]) + + messages = [{"role": "user", "content": "Hello"}] + headers = {"x-bf-vk": update_vk["value"]} + + def make_request(): + return governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + def update_vk_config(): + update_data = {"description": "Updated during request"} + return governance_client.update_virtual_key(update_vk["id"], update_data) + + # Start request and update concurrently + with ThreadPoolExecutor(max_workers=2) as executor: + request_future = executor.submit(make_request) + update_future = executor.submit(update_vk_config) + + request_response = request_future.result() + update_response = update_future.result() + + # Both should handle gracefully + assert request_response.status_code in [200, 402, 403, 429] + assert_response_success(update_response, 200) + + @pytest.mark.integration + @pytest.mark.edge_cases + def test_extreme_token_counts(self, governance_client, sample_virtual_key): + """Test extreme token count scenarios""" + headers = {"x-bf-vk": sample_virtual_key["value"]} + + # Test with 0 max_tokens + response = governance_client.chat_completion( + messages=[{"role": "user", "content": "Hello"}], + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=0, + ) + + # Should handle 0 tokens gracefully + assert response.status_code in [200, 400] + + # Test with very large max_tokens + response = governance_client.chat_completion( + messages=[{"role": "user", "content": "Hello"}], + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=100000, # Very large + ) + + # Should handle large token requests + assert response.status_code in [200, 400, 402, 429] + + @pytest.mark.integration + @pytest.mark.edge_cases + def test_empty_and_large_messages(self, governance_client, sample_virtual_key): + """Test empty and very large message scenarios""" + headers = {"x-bf-vk": sample_virtual_key["value"]} + + # Test with empty message + response = governance_client.chat_completion( + messages=[{"role": "user", "content": ""}], + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Should handle empty messages + assert response.status_code in [200, 400] + + # Test with very large message + large_content = "This is a very long message. " * 1000 + response = governance_client.chat_completion( + messages=[{"role": "user", "content": large_content}], + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=5, + ) + + # Should handle large messages + assert response.status_code in [200, 400, 402, 429] + + +class TestPerformanceAndScaling: + """Test performance and scaling of usage tracking""" + + @pytest.mark.integration + @pytest.mark.performance + @pytest.mark.slow + def test_high_frequency_requests(self, governance_client, cleanup_tracker): + """Test high frequency requests performance""" + # Create VK with high limits + vk_data = { + "name": generate_unique_name("High Frequency VK"), + "rate_limit": { + "request_max_limit": 1000, + "request_reset_duration": "1h", + "token_max_limit": 100000, + "token_reset_duration": "1h", + }, + "budget": {"max_limit": 1000000, "reset_duration": "1h"}, + } + create_response = governance_client.create_virtual_key(vk_data) + assert_response_success(create_response, 201) + high_freq_vk = create_response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(high_freq_vk["id"]) + + messages = [{"role": "user", "content": "Hi"}] + headers = {"x-bf-vk": high_freq_vk["value"]} + + # Measure performance of rapid requests + start_time = time.time() + responses = [] + + for i in range(20): # Make 20 rapid requests + response = governance_client.chat_completion( + messages=messages, + model="openai/gpt-3.5-turbo", + headers=headers, + max_tokens=1, + ) + responses.append(response.status_code) + if i % 5 == 0: + time.sleep(0.1) # Brief pause every 5 requests + + total_time = time.time() - start_time + + # Performance assertions + assert total_time < 30.0, f"20 requests took too long: {total_time}s" + + # Most requests should succeed (unless rate limited) + success_count = sum(1 for code in responses if code == 200) + print( + f"High frequency test: {success_count}/20 requests succeeded in {total_time:.2f}s" + ) + + @pytest.mark.integration + @pytest.mark.performance + @pytest.mark.slow + def test_usage_stats_performance(self, governance_client, cleanup_tracker): + """Test usage statistics endpoint performance""" + # Create multiple VKs and make requests + vk_ids = [] + for i in range(10): + vk_data = {"name": generate_unique_name(f"Stats Perf VK {i}")} + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 201) + vk_id = response.json()["virtual_key"]["id"] + vk_ids.append(vk_id) + cleanup_tracker.add_virtual_key(vk_id) + + # Test general stats performance + start_time = time.time() + response = governance_client.get_usage_stats() + stats_time = time.time() - start_time + + assert_response_success(response, 200) + assert stats_time < 2.0, f"Usage stats took too long: {stats_time}s" + + # Test individual VK stats performance + start_time = time.time() + for vk_id in vk_ids[:5]: # Test 5 VKs + response = governance_client.get_usage_stats(virtual_key_id=vk_id) + assert_response_success(response, 200) + + individual_stats_time = time.time() - start_time + assert ( + individual_stats_time < 5.0 + ), f"Individual VK stats took too long: {individual_stats_time}s" + + print( + f"Performance test: General stats: {stats_time:.2f}s, 5 individual stats: {individual_stats_time:.2f}s" + ) diff --git a/tests/governance/test_virtual_keys_crud.py b/tests/governance/test_virtual_keys_crud.py new file mode 100644 index 000000000..f2b025956 --- /dev/null +++ b/tests/governance/test_virtual_keys_crud.py @@ -0,0 +1,928 @@ +""" +Comprehensive Virtual Key CRUD Tests for Bifrost Governance Plugin + +This module provides exhaustive testing of Virtual Key operations including: +- Complete CRUD lifecycle testing +- Comprehensive field update testing (individual and batch) +- Mutual exclusivity validation (team_id vs customer_id) +- Budget and rate limit management +- Relationship testing with teams and customers +- Edge cases and validation scenarios +- Concurrency and race condition testing +""" + +import pytest +import time +import uuid +from typing import Dict, Any, List +from concurrent.futures import ThreadPoolExecutor +import copy + +from conftest import ( + assert_response_success, + verify_unchanged_fields, + generate_unique_name, + create_complete_virtual_key_data, + verify_entity_relationships, + deep_compare_entities, +) + + +class TestVirtualKeyBasicCRUD: + """Test basic CRUD operations for Virtual Keys""" + + @pytest.mark.virtual_keys + @pytest.mark.crud + @pytest.mark.smoke + def test_vk_create_minimal(self, governance_client, cleanup_tracker): + """Test creating VK with minimal required data""" + data = {"name": generate_unique_name("Minimal VK")} + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify required fields + assert vk_data["name"] == data["name"] + assert vk_data["value"] is not None # Auto-generated + assert vk_data["is_active"] is True # Default value + assert vk_data["id"] is not None + assert vk_data["created_at"] is not None + assert vk_data["updated_at"] is not None + + # Verify optional fields are None/empty + assert vk_data["allowed_models"] is None + assert vk_data["allowed_providers"] is None + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_create_complete(self, governance_client, cleanup_tracker): + """Test creating VK with all possible fields""" + data = create_complete_virtual_key_data() + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify all fields are set correctly + assert vk_data["name"] == data["name"] + assert vk_data["description"] == data["description"] + assert vk_data["allowed_models"] == data["allowed_models"] + assert vk_data["allowed_providers"] == data["allowed_providers"] + assert vk_data["is_active"] == data["is_active"] + + # Verify budget was created + assert vk_data["budget"] is not None + assert vk_data["budget"]["max_limit"] == data["budget"]["max_limit"] + assert vk_data["budget"]["reset_duration"] == data["budget"]["reset_duration"] + + # Verify rate limit was created + assert vk_data["rate_limit"] is not None + assert ( + vk_data["rate_limit"]["token_max_limit"] + == data["rate_limit"]["token_max_limit"] + ) + assert ( + vk_data["rate_limit"]["request_max_limit"] + == data["rate_limit"]["request_max_limit"] + ) + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_create_with_team(self, governance_client, cleanup_tracker, sample_team): + """Test creating VK associated with a team""" + data = {"name": generate_unique_name("Team VK"), "team_id": sample_team["id"]} + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify team association + assert vk_data["team_id"] == sample_team["id"] + assert vk_data.get("customer_id") is None + assert vk_data["team"] is not None + assert vk_data["team"]["id"] == sample_team["id"] + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_create_with_customer( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test creating VK associated with a customer""" + data = { + "name": generate_unique_name("Customer VK"), + "customer_id": sample_customer["id"], + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify customer association + assert vk_data["customer_id"] == sample_customer["id"] + assert vk_data.get("team_id") is None + assert vk_data["customer"] is not None + assert vk_data["customer"]["id"] == sample_customer["id"] + + @pytest.mark.virtual_keys + @pytest.mark.crud + @pytest.mark.mutual_exclusivity + def test_vk_create_mutual_exclusivity_violation( + self, governance_client, sample_team, sample_customer + ): + """Test that VK cannot be created with both team_id and customer_id""" + data = { + "name": generate_unique_name("Invalid VK"), + "team_id": sample_team["id"], + "customer_id": sample_customer["id"], + } + + response = governance_client.create_virtual_key(data) + assert response.status_code == 400 + error_data = response.json() + assert "cannot be attached to both" in error_data["error"].lower() + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_list_all(self, governance_client, sample_virtual_key): + """Test listing all virtual keys""" + response = governance_client.list_virtual_keys() + assert_response_success(response, 200) + + data = response.json() + assert "virtual_keys" in data + assert "count" in data + assert isinstance(data["virtual_keys"], list) + assert data["count"] >= 1 + + # Find our test VK + test_vk = next( + (vk for vk in data["virtual_keys"] if vk["id"] == sample_virtual_key["id"]), + None, + ) + assert test_vk is not None + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_get_by_id(self, governance_client, sample_virtual_key): + """Test getting VK by ID with relationships loaded""" + response = governance_client.get_virtual_key(sample_virtual_key["id"]) + assert_response_success(response, 200) + + vk_data = response.json()["virtual_key"] + assert vk_data["id"] == sample_virtual_key["id"] + assert vk_data["name"] == sample_virtual_key["name"] + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_get_nonexistent(self, governance_client): + """Test getting non-existent VK returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.get_virtual_key(fake_id) + assert response.status_code == 404 + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_delete(self, governance_client, cleanup_tracker): + """Test deleting a virtual key""" + # Create VK to delete + data = {"name": generate_unique_name("Delete Test VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + + # Delete VK + delete_response = governance_client.delete_virtual_key(vk_id) + assert_response_success(delete_response, 200) + + # Verify VK is gone + get_response = governance_client.get_virtual_key(vk_id) + assert get_response.status_code == 404 + + @pytest.mark.virtual_keys + @pytest.mark.crud + def test_vk_delete_nonexistent(self, governance_client): + """Test deleting non-existent VK returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.delete_virtual_key(fake_id) + assert response.status_code == 404 + + +class TestVirtualKeyValidation: + """Test validation rules for Virtual Key operations""" + + @pytest.mark.virtual_keys + @pytest.mark.validation + def test_vk_create_missing_name(self, governance_client): + """Test creating VK without name fails""" + data = {"description": "VK without name"} + response = governance_client.create_virtual_key(data) + assert response.status_code == 400 + + @pytest.mark.virtual_keys + @pytest.mark.validation + def test_vk_create_empty_name(self, governance_client): + """Test creating VK with empty name fails""" + data = {"name": ""} + response = governance_client.create_virtual_key(data) + assert response.status_code == 400 + + @pytest.mark.virtual_keys + @pytest.mark.validation + def test_vk_create_invalid_team_id(self, governance_client): + """Test creating VK with non-existent team_id""" + data = { + "name": generate_unique_name("Invalid Team VK"), + "team_id": str(uuid.uuid4()), + } + response = governance_client.create_virtual_key(data) + # Note: Depending on implementation, this might succeed with warning or fail + # Adjust assertion based on actual API behavior + + @pytest.mark.virtual_keys + @pytest.mark.validation + def test_vk_create_invalid_customer_id(self, governance_client): + """Test creating VK with non-existent customer_id""" + data = { + "name": generate_unique_name("Invalid Customer VK"), + "customer_id": str(uuid.uuid4()), + } + response = governance_client.create_virtual_key(data) + # Note: Adjust assertion based on actual API behavior + + @pytest.mark.virtual_keys + @pytest.mark.validation + def test_vk_create_invalid_json(self, governance_client): + """Test creating VK with malformed JSON""" + # This would be tested at the HTTP level, but pytest requests handles JSON encoding + # So we test with invalid data types instead + data = { + "name": 123, # Should be string + "is_active": "not_boolean", # Should be boolean + } + response = governance_client.create_virtual_key(data) + assert response.status_code == 400 + + +class TestVirtualKeyFieldUpdates: + """Comprehensive tests for Virtual Key field updates""" + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + def test_vk_update_individual_fields( + self, governance_client, cleanup_tracker, sample_team, sample_customer + ): + """Test updating each VK field individually""" + # Create complete VK for testing + original_data = create_complete_virtual_key_data() + create_response = governance_client.create_virtual_key(original_data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Get original state + original_response = governance_client.get_virtual_key(vk_id) + original_vk = original_response.json()["virtual_key"] + + # Test individual field updates + field_test_cases = [ + { + "field": "description", + "update_data": {"description": "Updated description"}, + "expected_value": "Updated description", + }, + { + "field": "allowed_models", + "update_data": {"allowed_models": ["gpt-4", "claude-3-opus"]}, + "expected_value": ["gpt-4", "claude-3-opus"], + }, + { + "field": "allowed_providers", + "update_data": {"allowed_providers": ["openai"]}, + "expected_value": ["openai"], + }, + { + "field": "is_active", + "update_data": {"is_active": False}, + "expected_value": False, + }, + { + "field": "team_id", + "update_data": {"team_id": sample_team["id"]}, + "expected_value": sample_team["id"], + "exclude_from_unchanged_check": [ + "team_id", + "customer_id", + "team", + "customer", + ], + }, + { + "field": "customer_id", + "update_data": {"customer_id": sample_customer["id"]}, + "expected_value": sample_customer["id"], + "exclude_from_unchanged_check": [ + "team_id", + "customer_id", + "team", + "customer", + ], + }, + ] + + for test_case in field_test_cases: + # Reset VK to original state by updating all fields back + reset_data = { + "description": original_vk.get("description", ""), + "allowed_models": original_vk["allowed_models"], + "allowed_providers": original_vk["allowed_providers"], + "is_active": original_vk["is_active"], + "team_id": original_vk.get("team_id"), + "customer_id": original_vk.get("customer_id"), + } + governance_client.update_virtual_key(vk_id, reset_data) + + # Perform field update + response = governance_client.update_virtual_key( + vk_id, test_case["update_data"] + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + + # Verify target field was updated + field_parts = test_case["field"].split(".") + current_value = updated_vk + for part in field_parts: + current_value = current_value[part] + assert ( + current_value == test_case["expected_value"] + ), f"Field {test_case['field']} not updated correctly" + + # Verify other fields unchanged (if specified) + if test_case.get("verify_unchanged", True): + exclude_fields = test_case.get( + "exclude_from_unchanged_check", [test_case["field"]] + ) + verify_unchanged_fields(updated_vk, original_vk, exclude_fields) + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + def test_vk_budget_updates(self, governance_client, cleanup_tracker): + """Test comprehensive budget creation, update, and modification""" + # Create VK without budget + data = {"name": generate_unique_name("Budget Test VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Test 1: Add budget to VK without budget + budget_data = {"max_limit": 10000, "reset_duration": "1h"} + response = governance_client.update_virtual_key(vk_id, {"budget": budget_data}) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["budget"]["max_limit"] == 10000 + assert updated_vk["budget"]["reset_duration"] == "1h" + assert updated_vk["budget_id"] is not None + + # Test 2: Update existing budget completely + new_budget_data = {"max_limit": 20000, "reset_duration": "2h"} + response = governance_client.update_virtual_key( + vk_id, {"budget": new_budget_data} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["budget"]["max_limit"] == 20000 + assert updated_vk["budget"]["reset_duration"] == "2h" + + # Test 3: Partial budget update (only max_limit) + response = governance_client.update_virtual_key( + vk_id, {"budget": {"max_limit": 30000}} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["budget"]["max_limit"] == 30000 + assert updated_vk["budget"]["reset_duration"] == "2h" # Should remain unchanged + + # Test 4: Partial budget update (only reset_duration) + response = governance_client.update_virtual_key( + vk_id, {"budget": {"reset_duration": "24h"}} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["budget"]["max_limit"] == 30000 # Should remain unchanged + assert updated_vk["budget"]["reset_duration"] == "24h" + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + def test_vk_rate_limit_updates(self, governance_client, cleanup_tracker): + """Test comprehensive rate limit creation, update, and field-level modifications""" + # Create VK without rate limit + data = {"name": generate_unique_name("Rate Limit Test VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Test 1: Add rate limit to VK + rate_limit_data = { + "token_max_limit": 1000, + "token_reset_duration": "1m", + "request_max_limit": 100, + "request_reset_duration": "1h", + } + response = governance_client.update_virtual_key( + vk_id, {"rate_limit": rate_limit_data} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["rate_limit"]["token_max_limit"] == 1000 + assert updated_vk["rate_limit"]["request_max_limit"] == 100 + assert updated_vk["rate_limit_id"] is not None + + # Test 2: Update only token limits + response = governance_client.update_virtual_key( + vk_id, + {"rate_limit": {"token_max_limit": 2000, "token_reset_duration": "2m"}}, + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["rate_limit"]["token_max_limit"] == 2000 + assert updated_vk["rate_limit"]["token_reset_duration"] == "2m" + assert updated_vk["rate_limit"]["request_max_limit"] == 100 # Unchanged + assert updated_vk["rate_limit"]["request_reset_duration"] == "1h" # Unchanged + + # Test 3: Update only request limits + response = governance_client.update_virtual_key( + vk_id, + {"rate_limit": {"request_max_limit": 200, "request_reset_duration": "2h"}}, + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["rate_limit"]["token_max_limit"] == 2000 # Unchanged + assert updated_vk["rate_limit"]["request_max_limit"] == 200 + assert updated_vk["rate_limit"]["request_reset_duration"] == "2h" + + # Test 4: Partial rate limit update (single field) + response = governance_client.update_virtual_key( + vk_id, {"rate_limit": {"token_max_limit": 5000}} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["rate_limit"]["token_max_limit"] == 5000 + assert updated_vk["rate_limit"]["token_reset_duration"] == "2m" # Unchanged + assert updated_vk["rate_limit"]["request_max_limit"] == 200 # Unchanged + assert updated_vk["rate_limit"]["request_reset_duration"] == "2h" # Unchanged + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + def test_vk_multiple_field_updates(self, governance_client, cleanup_tracker): + """Test updating multiple fields simultaneously""" + # Create VK with some initial data + initial_data = { + "name": generate_unique_name("Multi-Field Test VK"), + "description": "Initial description", + "allowed_models": ["gpt-3.5-turbo"], + "is_active": True, + } + create_response = governance_client.create_virtual_key(initial_data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Update multiple fields at once + update_data = { + "description": "Updated description via multi-field", + "allowed_models": ["gpt-4", "claude-3-5-sonnet-20240620"], + "allowed_providers": ["openai", "anthropic"], + "is_active": False, + "budget": {"max_limit": 50000, "reset_duration": "1d"}, + "rate_limit": { + "token_max_limit": 5000, + "request_max_limit": 500, + "token_reset_duration": "1h", + "request_reset_duration": "1h", + }, + } + + response = governance_client.update_virtual_key(vk_id, update_data) + assert_response_success(response, 200) + + updated_vk = response.json()["virtual_key"] + assert updated_vk["description"] == "Updated description via multi-field" + assert updated_vk["allowed_models"] == ["gpt-4", "claude-3-5-sonnet-20240620"] + assert updated_vk["allowed_providers"] == ["openai", "anthropic"] + assert updated_vk["is_active"] is False + assert updated_vk["budget"]["max_limit"] == 50000 + assert updated_vk["rate_limit"]["token_max_limit"] == 5000 + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + @pytest.mark.mutual_exclusivity + def test_vk_relationship_updates( + self, governance_client, cleanup_tracker, sample_team, sample_customer + ): + """Test updating VK relationships with mutual exclusivity validation""" + # Create standalone VK + data = {"name": generate_unique_name("Relationship Test VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Test 1: Add team relationship + response = governance_client.update_virtual_key( + vk_id, {"team_id": sample_team["id"]} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["team_id"] == sample_team["id"] + assert updated_vk.get("customer_id") is None + assert updated_vk["team"]["id"] == sample_team["id"] + + # Test 2: Switch to customer (should clear team) + response = governance_client.update_virtual_key( + vk_id, {"customer_id": sample_customer["id"]} + ) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + assert updated_vk["customer_id"] == sample_customer["id"] + assert updated_vk.get("team_id") is None + assert updated_vk["customer"]["id"] == sample_customer["id"] + assert updated_vk.get("team") is None + + # Test 3: Try to set both (should fail) + response = governance_client.update_virtual_key( + vk_id, {"team_id": sample_team["id"], "customer_id": sample_customer["id"]} + ) + assert response.status_code == 400 + error_data = response.json() + assert "cannot be attached to both" in error_data["error"].lower() + + # Test 4: Clear both relationships + response = governance_client.update_virtual_key( + vk_id, {"team_id": None, "customer_id": None} + ) + # Note: Behavior depends on API implementation - adjust based on actual behavior + # Some APIs might not support explicit null setting + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + @pytest.mark.edge_cases + def test_vk_update_edge_cases(self, governance_client, cleanup_tracker): + """Test edge cases in VK updates""" + # Create test VK + data = {"name": generate_unique_name("Edge Case VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + original_response = governance_client.get_virtual_key(vk_id) + original_vk = original_response.json()["virtual_key"] + + # Test 1: Empty update (should return unchanged VK) + response = governance_client.update_virtual_key(vk_id, {}) + assert_response_success(response, 200) + updated_vk = response.json()["virtual_key"] + + # Compare ignoring timestamps + differences = deep_compare_entities( + updated_vk, original_vk, ignore_fields=["updated_at"] + ) + assert len(differences) == 0, f"Empty update changed fields: {differences}" + + # Test 2: Invalid field values + response = governance_client.update_virtual_key(vk_id, {"is_active": "invalid"}) + assert response.status_code == 400 + + # Test 3: Update with same values (should succeed but might not change updated_at) + response = governance_client.update_virtual_key( + vk_id, + { + "description": original_vk.get("description", ""), + }, + ) + # Note: Adjust based on API behavior for no-op updates + + # Test 4: Very long values (test field length limits) + long_description = "x" * 10000 # Adjust based on actual field limits + response = governance_client.update_virtual_key( + vk_id, {"description": long_description} + ) + # Expected behavior depends on API validation rules + + @pytest.mark.virtual_keys + @pytest.mark.field_updates + def test_vk_update_nonexistent(self, governance_client): + """Test updating non-existent VK returns 404""" + fake_id = str(uuid.uuid4()) + response = governance_client.update_virtual_key( + fake_id, {"description": "test"} + ) + assert response.status_code == 404 + + +class TestVirtualKeyBudgetAndRateLimit: + """Test budget and rate limit specific functionality""" + + @pytest.mark.virtual_keys + @pytest.mark.budget + def test_vk_budget_creation_and_validation( + self, governance_client, cleanup_tracker + ): + """Test budget creation with various configurations""" + # Test valid budget configurations + budget_test_cases = [ + {"max_limit": 1000, "reset_duration": "1h"}, + {"max_limit": 50000, "reset_duration": "1d"}, + {"max_limit": 100000, "reset_duration": "1w"}, + {"max_limit": 1000000, "reset_duration": "1M"}, + ] + + for budget_config in budget_test_cases: + data = { + "name": generate_unique_name( + f"Budget VK {budget_config['reset_duration']}" + ), + "budget": budget_config, + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + assert vk_data["budget"]["max_limit"] == budget_config["max_limit"] + assert ( + vk_data["budget"]["reset_duration"] == budget_config["reset_duration"] + ) + assert vk_data["budget"]["current_usage"] == 0 + assert vk_data["budget"]["last_reset"] is not None + + @pytest.mark.virtual_keys + @pytest.mark.budget + @pytest.mark.edge_cases + def test_vk_budget_edge_cases(self, governance_client, cleanup_tracker): + """Test budget edge cases and boundary conditions""" + # Test boundary values + edge_case_budgets = [ + {"max_limit": 0, "reset_duration": "1h"}, # Zero budget + {"max_limit": 1, "reset_duration": "1s"}, # Minimal values + {"max_limit": 9223372036854775807, "reset_duration": "1h"}, # Max int64 + ] + + for budget_config in edge_case_budgets: + data = { + "name": generate_unique_name( + f"Edge Budget VK {budget_config['max_limit']}" + ), + "budget": budget_config, + } + + response = governance_client.create_virtual_key(data) + # Adjust assertions based on API validation rules + if ( + budget_config["max_limit"] >= 0 + ): # Assuming non-negative budgets are valid + assert_response_success(response, 201) + cleanup_tracker.add_virtual_key(response.json()["virtual_key"]["id"]) + else: + assert response.status_code == 400 + + @pytest.mark.virtual_keys + @pytest.mark.rate_limit + def test_vk_rate_limit_creation_and_validation( + self, governance_client, cleanup_tracker + ): + """Test rate limit creation with various configurations""" + # Test different rate limit configurations + rate_limit_test_cases = [ + { + "token_max_limit": 1000, + "token_reset_duration": "1m", + "request_max_limit": 100, + "request_reset_duration": "1h", + }, + { + "token_max_limit": 10000, + "token_reset_duration": "1h", + # Only token limits + }, + { + "request_max_limit": 500, + "request_reset_duration": "1d", + # Only request limits + }, + { + "token_max_limit": 5000, + "token_reset_duration": "30s", + "request_max_limit": 1000, + "request_reset_duration": "5m", + }, + ] + + for rate_limit_config in rate_limit_test_cases: + data = { + "name": generate_unique_name("Rate Limit VK"), + "rate_limit": rate_limit_config, + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + rate_limit = vk_data["rate_limit"] + for key, value in rate_limit_config.items(): + assert rate_limit[key] == value + + @pytest.mark.virtual_keys + @pytest.mark.rate_limit + @pytest.mark.edge_cases + def test_vk_rate_limit_edge_cases(self, governance_client, cleanup_tracker): + """Test rate limit edge cases and boundary conditions""" + # Test minimal rate limits + minimal_rate_limit = { + "token_max_limit": 1, + "token_reset_duration": "1s", + "request_max_limit": 1, + "request_reset_duration": "1s", + } + + data = { + "name": generate_unique_name("Minimal Rate Limit VK"), + "rate_limit": minimal_rate_limit, + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + cleanup_tracker.add_virtual_key(response.json()["virtual_key"]["id"]) + + # Test large rate limits + large_rate_limit = { + "token_max_limit": 1000000, + "token_reset_duration": "1h", + "request_max_limit": 100000, + "request_reset_duration": "1h", + } + + data = { + "name": generate_unique_name("Large Rate Limit VK"), + "rate_limit": large_rate_limit, + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + cleanup_tracker.add_virtual_key(response.json()["virtual_key"]["id"]) + + +class TestVirtualKeyConcurrency: + """Test concurrent operations on Virtual Keys""" + + @pytest.mark.virtual_keys + @pytest.mark.concurrency + @pytest.mark.slow + def test_vk_concurrent_creation(self, governance_client, cleanup_tracker): + """Test creating multiple VKs concurrently""" + + def create_vk(index): + data = {"name": generate_unique_name(f"Concurrent VK {index}")} + response = governance_client.create_virtual_key(data) + return response + + # Create 10 VKs concurrently + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(create_vk, i) for i in range(10)] + responses = [future.result() for future in futures] + + # Verify all succeeded + created_vks = [] + for response in responses: + assert_response_success(response, 201) + vk_data = response.json()["virtual_key"] + created_vks.append(vk_data) + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify all VKs have unique IDs and values + vk_ids = [vk["id"] for vk in created_vks] + vk_values = [vk["value"] for vk in created_vks] + assert len(set(vk_ids)) == 10 # All unique IDs + assert len(set(vk_values)) == 10 # All unique values + + @pytest.mark.virtual_keys + @pytest.mark.concurrency + @pytest.mark.slow + def test_vk_concurrent_updates(self, governance_client, cleanup_tracker): + """Test updating same VK concurrently""" + # Create VK to update + data = {"name": generate_unique_name("Concurrent Update VK")} + create_response = governance_client.create_virtual_key(data) + assert_response_success(create_response, 201) + vk_id = create_response.json()["virtual_key"]["id"] + cleanup_tracker.add_virtual_key(vk_id) + + # Update concurrently with different descriptions + def update_vk(index): + update_data = {"description": f"Updated by thread {index}"} + response = governance_client.update_virtual_key(vk_id, update_data) + return response, index + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_vk, i) for i in range(5)] + results = [future.result() for future in futures] + + # All updates should succeed (last one wins) + for response, index in results: + assert_response_success(response, 200) + + # Verify final state + final_response = governance_client.get_virtual_key(vk_id) + final_vk = final_response.json()["virtual_key"] + assert final_vk["description"].startswith("Updated by thread") + + +class TestVirtualKeyRelationships: + """Test VK relationships with teams and customers""" + + @pytest.mark.virtual_keys + @pytest.mark.relationships + def test_vk_team_relationship_loading( + self, governance_client, cleanup_tracker, sample_team_with_customer + ): + """Test that VK properly loads team and customer relationships""" + data = { + "name": generate_unique_name("Relationship VK"), + "team_id": sample_team_with_customer["id"], + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify team relationship loaded + assert vk_data["team"] is not None + assert vk_data["team"]["id"] == sample_team_with_customer["id"] + assert vk_data["team"]["name"] == sample_team_with_customer["name"] + + # Verify team's customer_id is present (nested customer not preloaded) + if sample_team_with_customer.get("customer_id"): + # Note: API only preloads one level deep, so customer object isn't nested here + assert ( + vk_data["team"].get("customer_id") + == sample_team_with_customer["customer_id"] + ) + + @pytest.mark.virtual_keys + @pytest.mark.relationships + def test_vk_customer_relationship_loading( + self, governance_client, cleanup_tracker, sample_customer + ): + """Test that VK properly loads customer relationships""" + data = { + "name": generate_unique_name("Customer Relationship VK"), + "customer_id": sample_customer["id"], + } + + response = governance_client.create_virtual_key(data) + assert_response_success(response, 201) + vk_data = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk_data["id"]) + + # Verify customer relationship loaded + assert vk_data["customer"] is not None + assert vk_data["customer"]["id"] == sample_customer["id"] + assert vk_data["customer"]["name"] == sample_customer["name"] + + @pytest.mark.virtual_keys + @pytest.mark.relationships + def test_vk_orphaned_relationships(self, governance_client, cleanup_tracker): + """Test VK behavior with orphaned team/customer references""" + # Create VK with non-existent team_id + fake_team_id = str(uuid.uuid4()) + data = {"name": generate_unique_name("Orphaned VK"), "team_id": fake_team_id} + + response = governance_client.create_virtual_key(data) + # Behavior depends on API implementation: + # - Might succeed with warning + # - Might fail with validation error + # Adjust assertion based on actual behavior + + if response.status_code == 201: + cleanup_tracker.add_virtual_key(response.json()["virtual_key"]["id"]) + # Verify VK was created but team relationship is null/missing + vk_data = response.json()["virtual_key"] + assert vk_data.get("team") is None + else: + assert response.status_code == 400 # Validation error expected diff --git a/tests/integrations/Makefile b/tests/integrations/Makefile new file mode 100644 index 000000000..2f0b2dc61 --- /dev/null +++ b/tests/integrations/Makefile @@ -0,0 +1,120 @@ +# Bifrost Python E2E Test Makefile +# Provides convenient commands for running tests + +# Get the directory where this Makefile is located +SCRIPT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) + +.PHONY: help install test test-all test-parallel test-verbose clean lint format check-env + +# Default target +help: + @echo "Bifrost Python E2E Test Commands:" + @echo "" + @echo "Setup:" + @echo " install Install Python dependencies" + @echo " check-env Check environment variables" + @echo "" + @echo "Testing:" + @echo " test Run all tests using master runner" + @echo " test-all Run all tests with pytest" + @echo " test-parallel Run tests in parallel" + @echo " test-verbose Run tests with verbose output" + @echo " test-openai Run OpenAI integration tests only" + @echo " test-anthropic Run Anthropic integration tests only" + @echo " test-litellm Run LiteLLM integration tests only" + @echo " test-langchain Run LangChain integration tests only" + @echo " test-langgraph Run LangGraph integration tests only" + @echo " test-mistral Run Mistral integration tests only" + @echo " test-genai Run Google GenAI integration tests only" + @echo "" + @echo "Development:" + @echo " lint Run code linting" + @echo " format Format code with black" + @echo " clean Clean up temporary files" + +# Setup commands +install: + pip install -r $(SCRIPT_DIR)requirements.txt + +check-env: + @echo "Checking environment variables..." + @python -c "import os; print('βœ“ BIFROST_BASE_URL:', os.getenv('BIFROST_BASE_URL', 'http://localhost:8080'))" + @python -c "import os; print('βœ“ OPENAI_API_KEY:', 'Set' if os.getenv('OPENAI_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ ANTHROPIC_API_KEY:', 'Set' if os.getenv('ANTHROPIC_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ MISTRAL_API_KEY:', 'Set' if os.getenv('MISTRAL_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ GOOGLE_API_KEY:', 'Set' if os.getenv('GOOGLE_API_KEY') else 'Not set')" + +# Testing commands using master runner +test: + python $(SCRIPT_DIR)run_all_tests.py + +test-parallel: + python $(SCRIPT_DIR)run_all_tests.py --parallel + +test-verbose: + python $(SCRIPT_DIR)run_all_tests.py --verbose + +test-list: + python $(SCRIPT_DIR)run_all_tests.py --list + +# Individual integration tests +test-openai: + python $(SCRIPT_DIR)run_all_tests.py --integration openai --verbose + +test-anthropic: + python $(SCRIPT_DIR)run_all_tests.py --integration anthropic --verbose + +test-litellm: + python $(SCRIPT_DIR)run_all_tests.py --integration litellm --verbose + +test-langchain: + python $(SCRIPT_DIR)run_all_tests.py --integration langchain --verbose + +test-langgraph: + python $(SCRIPT_DIR)run_all_tests.py --integration langgraph --verbose + +test-mistral: + python $(SCRIPT_DIR)run_all_tests.py --integration mistral --verbose + +test-genai: + python $(SCRIPT_DIR)run_all_tests.py --integration genai --verbose + +# Pytest commands +test-all: + pytest -v + +test-pytest-parallel: + pytest -v -n auto + +test-coverage: + pytest --cov=. --cov-report=html --cov-report=term + +# Development commands +lint: + @echo "Running flake8..." + cd $(SCRIPT_DIR) && flake8 *.py + @echo "Running mypy..." + cd $(SCRIPT_DIR) && mypy *.py + +format: + @echo "Formatting code with black..." + cd $(SCRIPT_DIR) && black *.py + +clean: + @echo "Cleaning up temporary files..." + cd $(SCRIPT_DIR) && rm -rf __pycache__/ + cd $(SCRIPT_DIR) && rm -rf .pytest_cache/ + cd $(SCRIPT_DIR) && rm -rf .coverage + cd $(SCRIPT_DIR) && rm -rf htmlcov/ + cd $(SCRIPT_DIR) && rm -rf .mypy_cache/ + cd $(SCRIPT_DIR) && find . -name "*.pyc" -delete + cd $(SCRIPT_DIR) && find . -name "*.pyo" -delete + +# Quick commands for common workflows +quick-test: check-env test + +all-tests: install check-env test-parallel + +dev-setup: install check-env + @echo "Development environment ready!" + @echo "Run 'make test' to execute all tests" \ No newline at end of file diff --git a/tests/integrations/README.md b/tests/integrations/README.md new file mode 100644 index 000000000..aa105e523 --- /dev/null +++ b/tests/integrations/README.md @@ -0,0 +1,1564 @@ +# Bifrost Integration Tests + +Production-ready end-to-end test suite for testing AI integrations through Bifrost proxy. This test suite provides uniform testing across multiple AI integrations with comprehensive coverage of chat, tool calling, image processing, embeddings, speech synthesis, and multimodal workflows. + +## πŸŒ‰ Architecture Overview + +The Bifrost integration tests use a centralized configuration system that routes all AI integration requests through Bifrost as a gateway/proxy: + +```text +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Test Client │───▢│ Bifrost Gateway │───▢│ AI Integration β”‚ +β”‚ β”‚ β”‚ localhost:8080 β”‚ β”‚ (OpenAI, etc.) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### URL Structure + +- **Base URL**: `http://localhost:8080` (configurable via `BIFROST_BASE_URL`) +- **Integration Endpoints**: + - OpenAI: `http://localhost:8080/openai` + - Anthropic: `http://localhost:8080/anthropic` + - Google: `http://localhost:8080/genai` + - LiteLLM: `http://localhost:8080/litellm` + +## πŸš€ Features + +- **πŸŒ‰ Bifrost Gateway Integration**: All integrations route through Bifrost proxy +- **πŸ€– Centralized Configuration**: YAML-based configuration with environment variable support +- **πŸ”§ Integration-Specific Clients**: Type-safe, integration-optimized implementations +- **πŸ“‹ Comprehensive Test Coverage**: 14 categories covering all major AI functionality +- **βš™οΈ Flexible Execution**: Selective test running with command-line flags +- **πŸ›‘οΈ Robust Error Handling**: Graceful error handling and detailed error reporting +- **🎯 Production-Ready**: Async support, timeouts, retries, and logging +- **🎡 Speech & Audio Support**: Text-to-speech synthesis and speech-to-text transcription testing +- **πŸ”— Embeddings Support**: Text-to-vector conversion and similarity analysis testing + +## πŸ“‹ Test Categories + +Our test suite covers 30 comprehensive scenarios for each integration: + +### Core Chat & Conversation Tests +1. **Simple Chat** - Basic single-message conversations +2. **Multi-turn Conversation** - Conversation history and context retention +3. **Streaming** - Real-time streaming responses and tool calls + +### Tool Calling & Function Tests +4. **Single Tool Call** - Basic function calling capabilities +5. **Multiple Tool Calls** - Multiple tools in single request +6. **End-to-End Tool Calling** - Complete tool workflow with results +7. **Automatic Function Calling** - Integration-managed tool execution + +### Image & Vision Tests +8. **Image Analysis (URL)** - Image processing from URLs +9. **Image Analysis (Base64)** - Image processing from base64 data +10. **Multiple Images** - Multi-image analysis and comparison + +### Speech & Audio Tests (OpenAI) +11. **Speech Synthesis** - Text-to-speech conversion with different voices +12. **Audio Transcription** - Speech-to-text conversion with multiple formats +13. **Transcription Streaming** - Real-time transcription processing +14. **Speech Round-Trip** - Complete textβ†’speechβ†’text workflow validation +15. **Speech Error Handling** - Invalid voice, model, and input error handling +16. **Transcription Error Handling** - Invalid audio format and model error handling +17. **Voice & Format Testing** - Multiple voices and audio format validation + +### Embeddings Tests (OpenAI) +18. **Single Text Embedding** - Basic text-to-vector conversion +19. **Batch Text Embeddings** - Multiple text embeddings in single request +20. **Embedding Similarity Analysis** - Cosine similarity testing for similar texts +21. **Embedding Dissimilarity Analysis** - Validation of different topic embeddings +22. **Different Embedding Models** - Testing various embedding model capabilities +23. **Long Text Embedding** - Handling of longer text inputs and token usage +24. **Embedding Error Handling** - Invalid model and input error processing +25. **Dimensionality Reduction** - Custom embedding dimensions (if supported) +26. **Encoding Format Testing** - Different embedding output formats +27. **Usage Tracking** - Token consumption and batch processing validation + +### Integration & Error Tests +28. **Complex End-to-End** - Comprehensive multimodal workflows +29. **Integration-Specific Features** - Integration-unique capabilities +30. **Error Handling** - Invalid request error processing and propagation + +## πŸ“ Directory Structure + +```text +transports-integrations/ +β”œβ”€β”€ config.yml # Central configuration file +β”œβ”€β”€ requirements.txt # Python dependencies +β”œβ”€β”€ run_all_tests.py # Test runner script +β”œβ”€β”€ run_integration_tests.py # Integration-specific test runner +β”œβ”€β”€ test_audio.py # Speech & transcription test runner +β”œβ”€β”€ pytest.ini # Pytest configuration +β”œβ”€β”€ Makefile # Convenience commands +β”œβ”€β”€ tests/ +β”‚ β”œβ”€β”€ conftest.py # Pytest configuration and fixtures +β”‚ β”œβ”€β”€ utils/ +β”‚ β”‚ β”œβ”€β”€ common.py # Shared test utilities and fixtures +β”‚ β”‚ β”œβ”€β”€ config_loader.py # Configuration system +β”‚ β”‚ └── models.py # Model configurations (compatibility layer) +β”‚ └── integrations/ +β”‚ β”œβ”€β”€ test_openai.py # OpenAI integration tests +β”‚ β”œβ”€β”€ test_anthropic.py # Anthropic integration tests +β”‚ β”œβ”€β”€ test_google.py # Google AI integration tests +β”‚ └── test_litellm.py # LiteLLM integration tests +``` + +## ⚑ Quick Start + +### 1. Installation + +```bash +# Clone the repository +git clone +cd bifrost/tests/transports-integrations + +# Option 1: Using Makefile (recommended) +make install + +# Option 2: Direct pip install +pip install -r requirements.txt +``` + +### 2. Configuration + +The system uses `config.yml` for centralized configuration. Set up your environment variables: + +```bash +# Required: Bifrost gateway +export BIFROST_BASE_URL="http://localhost:8080" + +# Required: Integration API keys +export OPENAI_API_KEY="your-openai-key" +export ANTHROPIC_API_KEY="your-anthropic-key" +export GOOGLE_API_KEY="your-google-api-key" + +# Optional: Integration-specific settings +export OPENAI_ORG_ID="org-..." +export OPENAI_PROJECT_ID="proj_..." +export GOOGLE_PROJECT_ID="your-project" +export GOOGLE_LOCATION="us-central1" +export TEST_ENV="development" + +# Quick check using Makefile +make check-env +``` + +### 3. Verify Configuration + +```bash +# Test the configuration system +python tests/utils/config_loader.py +``` + +This will display: + +- πŸŒ‰ Bifrost gateway URLs +- πŸ€– Model configurations +- βš™οΈ API settings +- βœ… Validation status + +### 4. Pytest Configuration + +The project includes a `pytest.ini` file with optimized settings: + +```ini +[pytest] +# Test discovery +testpaths = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output formatting +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes + +# Timeout settings (3 minutes per test) +timeout = 180 + +# Markers for test categorization +markers = + integration: marks tests as integration tests + slow: marks tests as slow running + e2e: marks tests as end-to-end tests + tool_calling: marks tests as tool calling tests +``` + +### 5. Run Tests + +```bash +# Option 1: Using Makefile (recommended for convenience) +make test # Run all tests using master runner +make test-openai # Run OpenAI tests only +make test-anthropic # Run Anthropic tests only +make test-genai # Run Google GenAI tests only +make test-litellm # Run LiteLLM tests only +make test-verbose # Run all tests with verbose output +make test-parallel # Run tests in parallel + +# Option 2: Using test runner scripts directly +python run_all_tests.py + +# Run specific integration +python run_integration_tests.py openai +python run_integration_tests.py anthropic +python run_integration_tests.py google +python run_integration_tests.py litellm + +# Option 3: Using pytest directly +pytest tests/integrations/test_openai.py -v + +# Run specific test categories +pytest tests/integrations/ -k "error_handling" -v # Run only error handling tests +pytest tests/integrations/ -k "test_12" -v # Run all 12th test cases (error handling) +``` + +#### Makefile Commands + +The project includes a `Makefile` with convenient commands: + +```bash +# Setup +make install # Install Python dependencies +make check-env # Check environment variables + +# Testing +make test # Run all tests using master runner +make test-all # Run all tests with pytest +make test-parallel # Run tests in parallel +make test-verbose # Run tests with verbose output +make test-openai # Run OpenAI integration tests only +make test-anthropic # Run Anthropic integration tests only +make test-genai # Run Google GenAI integration tests only +make test-litellm # Run LiteLLM integration tests only +make test-coverage # Run tests with coverage report + +# Development +make lint # Run code linting +make format # Format code with black +make clean # Clean up temporary files + +# Quick workflows +make quick-test # Check environment + run tests +make all-tests # Full install + check + parallel tests +make dev-setup # Setup development environment +``` + +## πŸ”§ Configuration System + +### Configuration Files + +#### 1. `config.yml` - Main Configuration + +Central configuration file containing: + +- Bifrost gateway settings and endpoints +- Model configurations for all integrations +- API settings (timeouts, retries) +- Test parameters and limits +- Environment-specific overrides +- Integration-specific settings + +#### 2. `tests/utils/config_loader.py` - Configuration Loader + +Python module that: + +- Loads and parses `config.yml` +- Expands environment variables with `${VAR:-default}` syntax +- Provides convenience functions for URLs and models +- Validates configuration completeness +- Handles error scenarios + +#### 3. `tests/utils/models.py` - Compatibility Layer + +Maintains backward compatibility while delegating to the new config system. + +### Key Configuration Sections + +#### Bifrost Gateway + +```yaml +bifrost: + base_url: "${BIFROST_BASE_URL:-http://localhost:8080}" + endpoints: + openai: "openai" + anthropic: "anthropic" + google: "genai" + litellm: "litellm" +``` + +#### Model Configurations + +```yaml +models: + openai: + chat: "gpt-3.5-turbo" + vision: "gpt-4o" + tools: "gpt-3.5-turbo" + speech: "tts-1" + transcription: "whisper-1" + alternatives: ["gpt-4", "gpt-4-turbo-preview", "gpt-4o", "gpt-4o-mini"] + speech_alternatives: ["tts-1-hd"] + transcription_alternatives: ["whisper-1"] +``` + +#### API Settings + +```yaml +api: + timeout: 30 + max_retries: 3 + retry_delay: 1 +``` + +### Usage Examples + +#### Getting Integration URLs + +```python +from tests.utils.config_loader import get_integration_url + +# Get Bifrost URL for OpenAI +openai_url = get_integration_url("openai") +# Returns: http://localhost:8080/openai + +# Get integration URL through Bifrost +openai_url = get_integration_url("openai") +# Returns: http://localhost:8080/openai +``` + +#### Getting Model Names + +```python +from tests.utils.config_loader import get_model + +# Get different model types +chat_model = get_model("openai", "chat") # "gpt-3.5-turbo" +vision_model = get_model("openai", "vision") # "gpt-4o" +speech_model = get_model("openai", "speech") # "tts-1" +transcription_model = get_model("openai", "transcription") # "whisper-1" +``` + +## 🎡 Speech & Transcription Testing + +The test suite includes comprehensive speech synthesis and transcription testing for supported integrations (currently OpenAI). + +### Speech & Audio Test Categories + +#### 1. Speech Synthesis (Text-to-Speech) +- **Basic synthesis**: Convert text to audio with different voices +- **Format testing**: Multiple audio formats (MP3, WAV, Opus) +- **Voice validation**: Test all available voices (alloy, echo, fable, onyx, nova, shimmer) +- **Parameter testing**: Response format, voice settings, and quality options + +#### 2. Speech Streaming +- **Real-time generation**: Streaming audio synthesis for large texts +- **Chunk validation**: Verify audio chunk integrity and format +- **Performance testing**: Measure streaming latency and throughput + +#### 3. Audio Transcription (Speech-to-Text) +- **File format support**: WAV, MP3, and other audio formats +- **Language detection**: Multi-language transcription capabilities +- **Parameter testing**: Language hints, response formats, temperature settings +- **Quality validation**: Transcription accuracy and completeness + +#### 4. Transcription Streaming +- **Real-time processing**: Streaming transcription for long audio files +- **Progressive results**: Incremental text output validation +- **Error handling**: Network interruption and recovery testing + +#### 5. Round-Trip Testing +- **Complete workflow**: Text β†’ Speech β†’ Transcription β†’ Text validation +- **Accuracy measurement**: Compare original text with round-trip result +- **Quality assessment**: Measure transcription fidelity and word preservation + +### Running Speech & Transcription Tests + +#### Quick Start + +```bash +# Run all speech and transcription tests +python test_audio.py + +# Run with verbose output +python test_audio.py --verbose + +# Run specific test +python test_audio.py --test test_14_speech_synthesis + +# List available tests +python test_audio.py --list +``` + +#### Individual Test Examples + +```bash +# Test speech synthesis +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_14_speech_synthesis -v + +# Test transcription +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_16_transcription_audio -v + +# Test round-trip workflow +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_18_speech_transcription_round_trip -v + +# Test error handling +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_19_speech_error_handling -v +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_20_transcription_error_handling -v +``` + +#### Available Test Audio Types + +1. **Sine Wave**: Pure tone audio for basic testing +2. **Chord**: Multi-frequency audio for complex signal testing +3. **Frequency Sweep**: Variable frequency audio for range testing +4. **White Noise**: Random audio for noise handling testing +5. **Silence**: Empty audio for edge case testing +6. **Various Durations**: Short (0.5s) to long (10s) audio files + +### Speech & Transcription Configuration + +#### Model Configuration + +```yaml +models: + openai: + speech: "tts-1" # Default speech synthesis model + transcription: "whisper-1" # Default transcription model + speech_alternatives: ["tts-1-hd"] # Higher quality speech model + transcription_alternatives: ["whisper-1"] # Alternative transcription models + +# Model capabilities +model_capabilities: + "tts-1": + speech: true + streaming: false # Streaming support varies + max_tokens: null + context_window: null + + "whisper-1": + transcription: true + streaming: false # Streaming support varies + max_tokens: null + context_window: null +``` + +#### Test Settings + +```yaml +test_settings: + max_tokens: + speech: null # Speech doesn't use token limits + transcription: null # Transcription doesn't use token limits + + timeouts: + speech: 60 # Speech generation timeout + transcription: 60 # Transcription processing timeout +``` + +### Speech Test Examples + +#### Basic Speech Synthesis + +```python +# Test basic speech synthesis +response = openai_client.audio.speech.create( + model="tts-1", + voice="alloy", + input="Hello, this is a test of speech synthesis.", +) +audio_content = response.content +assert len(audio_content) > 1000 # Ensure substantial audio data +``` + +#### Transcription Testing + +```python +# Test audio transcription +test_audio = generate_test_audio() # Generate test WAV file +response = openai_client.audio.transcriptions.create( + model="whisper-1", + file=("test.wav", test_audio, "audio/wav"), + language="en", +) +transcribed_text = response.text +assert len(transcribed_text) > 0 # Ensure transcription occurred +``` + +#### Round-Trip Validation + +```python +# Complete round-trip test +original_text = "The quick brown fox jumps over the lazy dog." + +# Step 1: Text to speech +speech_response = openai_client.audio.speech.create( + model="tts-1", + voice="alloy", + input=original_text, + response_format="wav", +) + +# Step 2: Speech to text +transcription_response = openai_client.audio.transcriptions.create( + model="whisper-1", + file=("speech.wav", speech_response.content, "audio/wav"), +) + +# Step 3: Validate similarity +transcribed_text = transcription_response.text +# Check for key word preservation (allowing for transcription variations) +``` + +### Error Handling Tests + +#### Speech Synthesis Errors + +```python +# Test invalid voice +with pytest.raises(Exception): + openai_client.audio.speech.create( + model="tts-1", + voice="invalid_voice", + input="This should fail", + ) + +# Test empty input +with pytest.raises(Exception): + openai_client.audio.speech.create( + model="tts-1", + voice="alloy", + input="", + ) +``` + +#### Transcription Errors + +```python +# Test invalid audio format +invalid_audio = b"This is not audio data" +with pytest.raises(Exception): + openai_client.audio.transcriptions.create( + model="whisper-1", + file=("invalid.wav", invalid_audio, "audio/wav"), + ) + +# Test unsupported file type +with pytest.raises(Exception): + openai_client.audio.transcriptions.create( + model="whisper-1", + file=("test.txt", b"text content", "text/plain"), + ) +``` + +### Integration Support Matrix + +| Integration | Speech Synthesis | Transcription | Streaming | Notes | +|------------|------------------|---------------|-----------|-------| +| OpenAI | βœ… Full Support | βœ… Full Support | πŸ”„ Varies | Complete implementation | +| Anthropic | ❌ Not Available | ❌ Not Available | ❌ No | No speech/audio APIs | +| Google | ❌ Not Available* | ❌ Not Available* | ❌ No | *Not through Gemini API | +| LiteLLM | βœ… Via OpenAI | βœ… Via OpenAI | πŸ”„ Varies | Proxies to OpenAI | + +*Note: Google offers speech services through separate APIs (Cloud Speech-to-Text, Cloud Text-to-Speech) that are not currently integrated.* + +### Performance Considerations + +#### Speech Synthesis +- **File Size**: Generated audio files range from 50KB to 5MB depending on length and quality +- **Generation Time**: Typically 2-10 seconds for short texts, longer for complex content +- **Format Impact**: WAV files are larger but offer better compatibility; MP3 is more compressed + +#### Transcription +- **Processing Time**: Usually 1-5 seconds for short audio files (under 30 seconds) +- **File Size Limits**: Most services support files up to 25MB +- **Accuracy Factors**: Audio quality, background noise, speaker clarity affect results + +### Best Practices + +#### For Speech Testing +1. **Use consistent test text** for reproducible results +2. **Test multiple voices** to ensure voice switching works +3. **Validate audio headers** to confirm proper format generation +4. **Check file sizes** to ensure reasonable audio generation + +#### For Transcription Testing +1. **Use high-quality test audio** for consistent transcription results +2. **Test various audio formats** (WAV, MP3, etc.) for compatibility +3. **Include silence and noise** tests for edge case handling +4. **Validate response formats** (JSON, text) as needed + +#### For Round-Trip Testing +1. **Use simple, clear phrases** to maximize transcription accuracy +2. **Allow for minor variations** in transcribed text +3. **Focus on key word preservation** rather than exact matches +4. **Test with different voices** to ensure consistency across voice models + +### Troubleshooting + +#### Common Issues + +1. **Audio Format Errors** + ```bash + # Check audio file headers + file test_audio.wav + # Should show: RIFF (little-endian) data, WAVE audio + ``` + +2. **API Key Issues** + ```bash + # Verify OpenAI API key + export OPENAI_API_KEY="your-key-here" + python test_audio.py --test test_14_speech_synthesis + ``` + +3. **Bifrost Configuration** + ```bash + # Ensure Bifrost is running and accessible + curl http://localhost:8080/openai/v1/audio/speech -I + ``` + +4. **Model Availability** + ```python + # Check if speech/transcription models are available + from tests.utils.config_loader import get_model + print("Speech model:", get_model("openai", "speech")) + print("Transcription model:", get_model("openai", "transcription")) + ``` + +#### Debug Commands + +```bash +# Test individual components +python test_audio.py --test test_14_speech_synthesis --verbose + +# Check Bifrost logs for audio endpoint requests +# (Check your Bifrost instance logs) +``` + +## Getting Model Names + +```python +from tests.utils.config_loader import get_model + +# Get chat model for OpenAI +chat_model = get_model("openai", "chat") +# Returns: gpt-3.5-turbo + +# Get vision model for Anthropic +vision_model = get_model("anthropic", "vision") +# Returns: claude-3-haiku-20240307 +``` + +## πŸ€– Integration Support + +### Currently Supported Integrations + +#### OpenAI + +- βœ… **Full Bifrost Integration**: Complete base URL support +- βœ… **Models**: gpt-3.5-turbo, gpt-4, gpt-4o, gpt-4o-mini, text-embedding-3-small, tts-1, whisper-1 +- βœ… **Features**: Chat, tools, vision, speech synthesis, transcription, embeddings +- βœ… **Settings**: Organization/project IDs, timeouts, retries +- βœ… **All Test Categories**: 30/30 scenarios supported (including speech & embeddings) + +#### Anthropic + +- βœ… **Full Bifrost Integration**: Complete base URL support +- βœ… **Models**: claude-3-haiku-20240307, claude-3-sonnet-20240229, claude-3-opus-20240229, claude-3-5-sonnet-20241022 +- βœ… **Features**: Chat, tools, vision +- βœ… **Settings**: API version headers, timeouts, retries +- βœ… **All Test Categories**: 11/11 scenarios supported + +#### Google AI + +- βœ… **Full Bifrost Integration**: Complete custom transport implementation +- βœ… **Models**: gemini-2.0-flash-001, gemini-1.5-pro, gemini-1.5-flash, gemini-1.0-pro +- βœ… **Features**: Chat, tools, vision, multimodal processing +- βœ… **Settings**: Project ID, location, API configuration +- βœ… **All Test Categories**: 11/11 scenarios supported +- βœ… **Custom Base64 Handling**: Resolved cross-language encoding compatibility + +#### LiteLLM + +- βœ… **Full Bifrost Integration**: Global base URL configuration +- βœ… **Models**: Supports all LiteLLM-compatible models +- βœ… **Features**: Chat, tools, vision (integration-dependent) +- βœ… **Settings**: Drop params, debug mode, integration-specific configs +- βœ… **All Test Categories**: 11/11 scenarios supported +- βœ… **Multi-Integration**: OpenAI, Anthropic, Google, Azure, Cohere, Mistral, etc. + +## πŸ§ͺ Running Tests + +### Test Execution Methods + +#### 1. Using Test Runner Scripts + +##### `run_integration_tests.py` - Advanced Integration Testing + +```bash +# Basic usage - run all available integrations +python run_integration_tests.py + +# Run specific integration +python run_integration_tests.py --integrations openai + +# Run multiple integrations +python run_integration_tests.py --integrations openai anthropic google + +# Run specific test across integrations +python run_integration_tests.py --integrations openai anthropic --test "test_03_single_tool_call" + +# Run test pattern (e.g., all tool calling tests) +python run_integration_tests.py --integrations google --test "tool_call" + +# Run with verbose output +python run_integration_tests.py --integrations openai --test "test_01_simple_chat" --verbose + +# Utility commands +python run_integration_tests.py --check-keys # Check API key availability +python run_integration_tests.py --show-models # Show model configuration +``` + +##### `run_all_tests.py` - Simple Sequential Testing + +```bash +# Run all integrations sequentially +python run_all_tests.py + +# Run with custom configuration +BIFROST_BASE_URL=https://your-bifrost.com python run_all_tests.py +``` + +#### 2. Using pytest Directly + +```bash +# Run all tests for a integration +pytest tests/integrations/test_openai.py -v + +# Run specific test categories +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v + +# Run with coverage +pytest tests/integrations/ --cov=tests --cov-report=html + +# Run with custom markers +pytest tests/integrations/ -m "not slow" -v +``` + +#### 3. Selective Test Execution + +```bash +# Skip tests that require API keys you don't have +pytest tests/integrations/test_openai.py -v # Will skip if OPENAI_API_KEY not set + +# Run only specific test methods +pytest tests/integrations/test_anthropic.py -k "tool_call" -v + +# Run with timeout +pytest tests/integrations/ --timeout=300 -v +``` + +### πŸ” Checking and Running Specific Tests + +#### πŸš€ Quick Commands (Most Common) + +```bash +# Run specific test for specific integration (your example!) +python run_integration_tests.py --integrations google --test "test_03_single_tool_call" + +# Run all tool calling tests across multiple integrations +python run_integration_tests.py --integrations openai anthropic --test "tool_call" + +# Run all tests for one integration +python run_integration_tests.py --integrations openai -v + +# Check what integrations are available +python run_integration_tests.py --check-keys + +# Run specific test with pytest directly +pytest tests/integrations/test_google.py::TestGoogleIntegration::test_03_single_tool_call -v +``` + +#### Quick Reference: Test Categories + +```text +Test 01: Simple Chat - Basic single-message conversations +Test 02: Multi-turn Conversation - Conversation history and context +Test 03: Single Tool Call - Basic function calling +Test 04: Multiple Tool Calls - Multiple tools in one request +Test 05: End-to-End Tool Calling - Complete tool workflow with results +Test 06: Automatic Function Call - Integration-managed tool execution +Test 07: Image Analysis (URL) - Image processing from URLs +Test 08: Image Analysis (Base64) - Image processing from base64 +Test 09: Multiple Images - Multi-image analysis and comparison +Test 10: Complex End-to-End - Comprehensive multimodal workflows +Test 11: Integration-Specific - Integration-unique features +``` + +#### Listing Available Tests + +```bash +# List all tests for a specific integration +pytest tests/integrations/test_openai.py --collect-only + +# List all test methods with descriptions +pytest tests/integrations/test_openai.py --collect-only -q + +# Show test structure for all integrations +pytest tests/integrations/ --collect-only +``` + +#### Running Individual Test Categories + +```bash +# Test 1: Simple Chat +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v + +# Test 3: Single Tool Call +pytest tests/integrations/test_anthropic.py::TestAnthropicIntegration::test_03_single_tool_call -v + +# Test 7: Image Analysis (URL) +pytest tests/integrations/test_google.py::TestGoogleIntegration::test_07_image_url -v + +# Test 9: Multiple Images +pytest tests/integrations/test_litellm.py::TestLiteLLMIntegration::test_09_multiple_images -v + +# Test 21: Single Text Embedding (OpenAI only) +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_21_single_text_embedding -v + +# Test 23: Embedding Similarity Analysis (OpenAI only) +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_23_embedding_similarity_analysis -v +``` + +#### Running Test Categories by Pattern + +```bash +# Run all simple chat tests across integrations +pytest tests/integrations/ -k "test_01_simple_chat" -v + +# Run all tool calling tests (single and multiple) +pytest tests/integrations/ -k "tool_call" -v + +# Run all image-related tests +pytest tests/integrations/ -k "image" -v + +# Run all embedding tests (OpenAI only) +pytest tests/integrations/test_openai.py -k "embedding" -v + +# Run all speech and audio tests (OpenAI only) +pytest tests/integrations/test_openai.py -k "speech or transcription" -v + +# Run all end-to-end tests +pytest tests/integrations/ -k "end2end" -v + +# Run integration-specific feature tests +pytest tests/integrations/ -k "integration_specific" -v +``` + +#### Running Tests by Integration + +```bash +# Run all OpenAI tests +pytest tests/integrations/test_openai.py -v + +# Run all Anthropic tests with detailed output +pytest tests/integrations/test_anthropic.py -v -s + +# Run Google tests with coverage +pytest tests/integrations/test_google.py --cov=tests --cov-report=term-missing -v + +# Run LiteLLM tests with timing +pytest tests/integrations/test_litellm.py --durations=10 -v +``` + +#### Advanced Test Selection + +```bash +# Run tests 1-5 (basic functionality) for OpenAI +pytest tests/integrations/test_openai.py -k "test_01 or test_02 or test_03 or test_04 or test_05" -v + +# Run only vision tests (tests 7, 8, 9, 10) +pytest tests/integrations/ -k "test_07 or test_08 or test_09 or test_10" -v + +# Run tests excluding images (skip tests 7, 8, 9, 10) +pytest tests/integrations/ -k "not (test_07 or test_08 or test_09 or test_10)" -v + +# Run only tool-related tests (tests 3, 4, 5, 6) +pytest tests/integrations/ -k "test_03 or test_04 or test_05 or test_06" -v +``` + +#### Test Status and Validation + +```bash +# Check which tests would run (dry run) +pytest tests/integrations/test_openai.py --collect-only --quiet + +# Validate test setup without running +pytest tests/integrations/test_openai.py --setup-only -v + +# Run tests with immediate failure reporting +pytest tests/integrations/ -x -v # Stop on first failure + +# Run tests with detailed failure information +pytest tests/integrations/ --tb=long -v +``` + +#### Integration-Specific Test Validation + +```bash +# Check if integration supports all test categories +python -c " +from tests.integrations.test_openai import TestOpenAIIntegration +import inspect +methods = [m for m in dir(TestOpenAIIntegration) if m.startswith('test_')] +print('OpenAI Test Methods:') +for i, method in enumerate(sorted(methods), 1): + print(f' {i:2d}. {method}') +print(f'Total: {len(methods)} tests') +" + +# Verify integration configuration +python -c " +from tests.utils.config_loader import get_config, get_model +config = get_config() +integration = 'openai' +print(f'{integration.upper()} Configuration:') +for model_type in ['chat', 'vision', 'tools']: + try: + model = get_model(integration, model_type) + print(f' {model_type}: {model}') + except Exception as e: + print(f' {model_type}: ERROR - {e}') +" +``` + +#### Test Results Analysis + +```bash +# Run tests with detailed reporting +pytest tests/integrations/test_openai.py -v --tb=short --report=term-missing + +# Generate HTML test report +pytest tests/integrations/ --html=test_report.html --self-contained-html + +# Run tests with JSON output for analysis +pytest tests/integrations/test_openai.py --json-report --json-report-file=openai_results.json + +# Compare test results across integrations +pytest tests/integrations/ -v | grep -E "(PASSED|FAILED|SKIPPED)" | sort +``` + +#### Debugging Specific Tests + +```bash +# Debug a failing test with full output +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s --tb=long + +# Run test with Python debugger +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call --pdb + +# Run test with custom logging +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call --log-cli-level=DEBUG -s + +# Test with environment variable override +OPENAI_API_KEY=sk-test pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v +``` + +#### Practical Testing Scenarios + +```bash +# Scenario 1: Test a new integration integration +# 1. Check configuration +python tests/utils/config_loader.py + +# 2. List available tests +pytest tests/integrations/test_your_integration.py --collect-only + +# 3. Run basic tests first (using test runner) +python run_integration_tests.py --integrations your_integration --test "test_01 or test_02" -v + +# 4. Test tool calling if supported (using test runner) +python run_integration_tests.py --integrations your_integration --test "tool_call" -v + +# Alternative: Direct pytest approach +pytest tests/integrations/test_your_integration.py -k "test_01 or test_02" -v +pytest tests/integrations/test_your_integration.py -k "tool_call" -v + +# Scenario 2: Debug a failing tool call test +# 1. Run with full debugging +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s --tb=long + +# 2. Check tool extraction function +python -c " +from tests.integrations.test_openai import extract_openai_tool_calls +print('Tool extraction function available:', callable(extract_openai_tool_calls)) +" + +# 3. Test with different model +OPENAI_CHAT_MODEL=gpt-4 pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v + +# Scenario 3: Compare integration capabilities +# Run the same test across all integrations (using test runner) +python run_integration_tests.py --integrations openai anthropic google litellm --test "test_01_simple_chat" -v + +# Alternative: Direct pytest approach +pytest tests/integrations/ -k "test_01_simple_chat" -v --tb=short + +# Scenario 4: Test only supported features +# For a integration that doesn't support images +pytest tests/integrations/test_your_integration.py -k "not (test_07 or test_08 or test_09 or test_10)" -v + +# Scenario 5: Performance testing +# Run with timing to identify slow tests +pytest tests/integrations/test_openai.py --durations=0 -v + +# Scenario 6: Continuous integration testing +# Run all tests with coverage and reports +pytest tests/integrations/ --cov=tests --cov-report=xml --junit-xml=test_results.xml -v +``` + +#### Test Output Examples + +```bash +# Successful test run +$ pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v +========================= test session starts ========================= +tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat PASSED [100%] +βœ“ OpenAI simple chat test passed +Response: "Hello! I'm Claude, an AI assistant. How can I help you today?" + +# Failed test with debugging info +$ pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s +========================= FAILURES ========================= +_____________ TestOpenAIIntegration.test_03_single_tool_call _____________ +AssertionError: Expected tool calls but got none +Response content: "I can help with weather information, but I need a specific location." +Tool calls found: [] + +# Test collection output +$ pytest tests/integrations/test_openai.py --collect-only -q +tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat +tests/integrations/test_openai.py::TestOpenAIIntegration::test_02_multi_turn_conversation +tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call +tests/integrations/test_openai.py::TestOpenAIIntegration::test_04_multiple_tool_calls +tests/integrations/test_openai.py::TestOpenAIIntegration::test_05_end2end_tool_calling +tests/integrations/test_openai.py::TestOpenAIIntegration::test_06_automatic_function_calling +tests/integrations/test_openai.py::TestOpenAIIntegration::test_07_image_url +tests/integrations/test_openai.py::TestOpenAIIntegration::test_08_image_base64 +tests/integrations/test_openai.py::TestOpenAIIntegration::test_09_multiple_images +tests/integrations/test_openai.py::TestOpenAIIntegration::test_10_complex_end2end +tests/integrations/test_openai.py::TestOpenAIIntegration::test_11_integration_specific_features +11 tests collected + +# Test runner script output +$ python run_integration_tests.py --integrations google --test "test_03_single_tool_call" -v +πŸš€ Starting integration tests... +πŸ“‹ Testing integrations: google +============================================================ +πŸ§ͺ TESTING GOOGLE INTEGRATION +============================================================ +========================= test session starts ========================= +tests/integrations/test_google.py::TestGoogleIntegration::test_03_single_tool_call PASSED [100%] +βœ… GOOGLE tests PASSED + +================================================================================ +🎯 FINAL SUMMARY +================================================================================ + +πŸ”‘ API Key Status: + βœ… GOOGLE: Available + +πŸ“Š Test Results: + βœ… GOOGLE: All tests passed + +πŸ† Overall Results: + Integrations tested: 1 + Integrations passed: 1 + Success rate: 100.0% +``` + +### Environment Variables + +#### Required Variables + +```bash +# Bifrost gateway (required) +export BIFROST_BASE_URL="http://localhost:8080" + +# Integration API keys (at least one required) +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +export GOOGLE_API_KEY="AIza..." +``` + +#### Optional Variables + +```bash +# Integration-specific settings +export OPENAI_ORG_ID="org-..." +export OPENAI_PROJECT_ID="proj_..." +export GOOGLE_PROJECT_ID="your-project" +export GOOGLE_LOCATION="us-central1" + +# Environment configuration +export TEST_ENV="development" # or "production" +``` + +### Test Output and Debugging + +#### Understanding Test Results + +```bash +# Successful test output +βœ“ OpenAI Integration Tests + βœ“ test_01_simple_chat - Response: "Hello! How can I help you today?" + βœ“ test_03_single_tool_call - Tool called: get_weather(location="New York") + βœ“ test_07_image_url - Image analyzed successfully + +# Failed test output +βœ— test_03_single_tool_call - AssertionError: Expected tool calls but got none + Response content: "I can help with weather, but I need a specific location." +``` + +#### Debug Mode + +```bash +# Enable verbose output +pytest tests/integrations/test_openai.py -v -s + +# Show full tracebacks +pytest tests/integrations/test_openai.py --tb=long + +# Enable debug logging +pytest tests/integrations/test_openai.py --log-cli-level=DEBUG +``` + +## πŸ”¨ Adding New Integrations + +### Step-by-Step Guide + +#### 1. Update Configuration + +Add your integration to `config.yml`: + +```yaml +# Add to bifrost endpoints +bifrost: + endpoints: + your_integration: "/your_integration" + +# Add model configuration +models: + your_integration: + chat: "your-chat-model" + vision: "your-vision-model" + tools: "your-tools-model" + alternatives: ["alternative-model-1", "alternative-model-2"] + +# Add model capabilities +model_capabilities: + "your-chat-model": + chat: true + tools: true + vision: false + max_tokens: 4096 + context_window: 8192 + +# Add integration settings +integration_settings: + your_integration: + api_version: "v1" + custom_header: "value" +``` + +#### 2. Create Integration Test File + +Create `tests/integrations/test_your_integration.py`: + +```python +""" +Your Integration Tests + +Tests all 11 core scenarios using Your Integration SDK. +""" + +import pytest +from your_integration_sdk import YourIntegrationClient + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + # ... import all test fixtures + get_api_key, + skip_if_no_api_key, + get_model, +) + + +@pytest.fixture +def your_integration_client(): + """Create Your Integration client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("your_integration") + base_url = get_integration_url("your_integration") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("your_integration") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add integration-specific settings + if integration_settings.get("api_version"): + client_kwargs["api_version"] = integration_settings["api_version"] + + return YourIntegrationClient(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +class TestYourIntegrationIntegration: + """Test suite for Your Integration covering all 11 core scenarios""" + + @skip_if_no_api_key("your_integration") + def test_01_simple_chat(self, your_integration_client, test_config): + """Test Case 1: Simple chat interaction""" + response = your_integration_client.chat.create( + model=get_model("your_integration", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.content is not None + assert len(response.content) > 0 + + # ... implement all 11 test methods following the same pattern + # See existing integration test files for complete examples + + +def extract_your_integration_tool_calls(response) -> List[Dict[str, Any]]: + """Extract tool calls from Your Integration response format""" + tool_calls = [] + + # Implement based on your integration's response format + if hasattr(response, 'tool_calls') and response.tool_calls: + for tool_call in response.tool_calls: + tool_calls.append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) + + return tool_calls +``` + +#### 3. Update Common Utilities + +Add your integration to `tests/utils/common.py`: + +```python +def get_api_key(integration: str) -> str: + """Get API key for integration""" + key_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "litellm": "LITELLM_API_KEY", + "your_integration": "YOUR_INTEGRATION_API_KEY", # Add this line + } + + env_var = key_map.get(integration) + if not env_var: + raise ValueError(f"Unknown integration: {integration}") + + api_key = os.getenv(env_var) + if not api_key: + raise ValueError(f"{env_var} environment variable not set") + + return api_key +``` + +#### 4. Add Integration-Specific Tool Extraction + +Update the tool extraction functions in your test file: + +```python +def extract_your_integration_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from Your Integration response format""" + tool_calls = [] + + try: + # Implement based on your integration's response structure + # Example for a hypothetical integration: + if hasattr(response, 'function_calls'): + for fc in response.function_calls: + tool_calls.append({ + "name": fc.name, + "arguments": fc.parameters + }) + + return tool_calls + + except Exception as e: + print(f"Error extracting tool calls: {e}") + return [] +``` + +#### 5. Test Your Implementation + +```bash +# Set up environment +export YOUR_INTEGRATION_API_KEY="your-api-key" +export BIFROST_BASE_URL="http://localhost:8080" + +# Test configuration +python tests/utils/config_loader.py + +# Run your integration tests +pytest tests/integrations/test_your_integration.py -v + +# Run specific test +pytest tests/integrations/test_your_integration.py::TestYourIntegrationIntegration::test_01_simple_chat -v +``` + +### 🎯 Key Implementation Points + +#### 1. **Follow the Pattern** + +- Use existing integration test files as templates +- Implement all 11 test scenarios +- Follow the same naming conventions and structure + +#### 2. **Handle Integration Differences** + +```python +# Example: Different response formats +def assert_valid_chat_response(response): + """Validate chat response - adapt for your integration""" + if hasattr(response, 'choices'): # OpenAI-style + assert response.choices[0].message.content + elif hasattr(response, 'content'): # Anthropic-style + assert response.content[0].text + elif hasattr(response, 'text'): # Google-style + assert response.text + # Add your integration's format here +``` + +#### 3. **Implement Tool Calling** + +```python +def convert_to_your_integration_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to your integration's format""" + your_integration_tools = [] + + for tool in tools: + # Convert to your integration's tool schema + your_integration_tools.append({ + "name": tool["name"], + "description": tool["description"], + "parameters": tool["parameters"], + # Add integration-specific fields + }) + + return your_integration_tools +``` + +#### 4. **Handle Image Processing** + +```python +def convert_to_your_integration_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common message format to your integration's format""" + your_integration_messages = [] + + for msg in messages: + if isinstance(msg.get("content"), list): + # Handle multimodal content (text + images) + content = [] + for item in msg["content"]: + if item["type"] == "text": + content.append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + # Convert to your integration's image format + content.append({ + "type": "image", + "source": item["image_url"]["url"] + }) + your_integration_messages.append({"role": msg["role"], "content": content}) + else: + your_integration_messages.append(msg) + + return your_integration_messages +``` + +#### 5. **Error Handling** + +```python +@skip_if_no_api_key("your_integration") +def test_03_single_tool_call(self, your_integration_client, test_config): + """Test Case 3: Single tool call""" + try: + response = your_integration_client.chat.create( + model=get_model("your_integration", "tools"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=convert_to_your_integration_tools([WEATHER_TOOL]), + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_your_integration_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + except Exception as e: + pytest.skip(f"Tool calling not supported or failed: {e}") +``` + +### πŸ” Testing Checklist + +Before submitting your integration implementation: + +- [ ] **Configuration**: Integration added to `config.yml` with all required sections +- [ ] **Environment**: API key environment variable documented and tested +- [ ] **All 11 Tests**: Every test scenario implemented and passing +- [ ] **Tool Extraction**: Integration-specific tool call extraction function +- [ ] **Message Conversion**: Proper handling of multimodal messages +- [ ] **Error Handling**: Graceful handling of unsupported features +- [ ] **Documentation**: Integration added to README with capabilities +- [ ] **Bifrost Integration**: Base URL properly configured and tested + +### 🚨 Common Pitfalls + +1. **Incorrect Response Parsing**: Each integration has different response formats +2. **Tool Schema Differences**: Tool calling schemas vary significantly +3. **Image Format Handling**: Base64 vs URL handling differs per integration +4. **Missing Error Handling**: Some integrations don't support all features +5. **Configuration Errors**: Forgetting to add integration to all config sections + +## πŸ”§ Troubleshooting + +### Common Issues + +#### 1. Configuration Problems + +```bash +# Error: Configuration file not found +FileNotFoundError: Configuration file not found: config.yml + +# Solution: Ensure config.yml exists in project root +ls -la config.yml +``` + +#### 2. Integration Connection Issues + +```bash +# Error: Connection refused to Bifrost +ConnectionError: Connection refused to localhost:8080 + +# Solutions: +# 1. Check if Bifrost is running +curl http://localhost:8080/health + +# 2. Ensure BIFROST_BASE_URL is set correctly +echo $BIFROST_BASE_URL +``` + +#### 3. API Key Issues + +```bash +# Error: API key not set +ValueError: OPENAI_API_KEY environment variable not set + +# Solution: Set required environment variables +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +export GOOGLE_API_KEY="AIza..." +``` + +#### 4. Model Configuration Errors + +```bash +# Error: Unknown model type +ValueError: Unknown model type 'vision' for integration 'your_integration' + +# Solution: Check config.yml has all model types defined +python tests/utils/config_loader.py +``` + +#### 5. Test Failures + +```bash +# Error: Tool calls not found +AssertionError: Response should contain tool calls + +# Debug steps: +# 1. Check if integration supports tool calling +# 2. Verify tool extraction function +# 3. Check integration-specific tool format +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s +``` + +### Debug Mode + +Enable comprehensive debugging: + +```bash +# Full verbose output with debugging +pytest tests/integrations/test_openai.py -v -s --tb=long --log-cli-level=DEBUG + +# Test configuration system +python tests/utils/config_loader.py + +# Check specific integration URL +python -c " +from tests.utils.config_loader import get_integration_url, get_model +print('OpenAI URL:', get_integration_url('openai')) +print('OpenAI Chat Model:', get_model('openai', 'chat')) +" +``` + +## πŸ“š Additional Resources + +### Configuration Examples + +- See `config.yml` for complete configuration reference +- Check `tests/utils/config_loader.py` for usage examples +- Review integration test files for implementation patterns + +### Contributing + +1. Fork the repository +2. Create feature branch: `git checkout -b feature/new-integration` +3. Follow the integration implementation guide above +4. Add comprehensive tests and documentation +5. Submit pull request with test results + +## πŸ†˜ Support + +For issues and questions: + +- Create GitHub issues for bugs and feature requests +- Check existing issues for solutions +- Review integration-specific documentation +- Test configuration with `python tests/utils/config_loader.py` + +--- + +**Note**: This test suite is designed for testing AI integrations through Bifrost proxy. Ensure your Bifrost instance is properly configured and running before executing tests. The configuration system provides Bifrost routing for maximum flexibility. diff --git a/tests/integrations/config.yml b/tests/integrations/config.yml new file mode 100644 index 000000000..5ed543de3 --- /dev/null +++ b/tests/integrations/config.yml @@ -0,0 +1,342 @@ +# Bifrost Integration Tests Configuration +# This file centralizes all configuration for AI integration clients and test settings + +# Bifrost Gateway Configuration +# All integrations route through Bifrost as a proxy/gateway +bifrost: + base_url: "${BIFROST_BASE_URL:-http://localhost:8080}" + + # Integration-specific endpoints (suffixes appended to base_url) + endpoints: + openai: "openai" + anthropic: "anthropic" + google: "genai" + litellm: "litellm" + langchain: "langchain" + + # Full URLs constructed as: {base_url.rstrip('/')}/{endpoints[integration]} + # Examples: + # - OpenAI: http://localhost:8080/openai + # - Anthropic: http://localhost:8080/anthropic + # - Google: http://localhost:8080/genai + # - LiteLLM: http://localhost:8080/litellm + # - LangChain: http://localhost:8080/langchain + +# API Configuration +api: + timeout: 30 # seconds + max_retries: 3 + retry_delay: 1 # seconds + +# Model configurations for each integration +models: + openai: + chat: "gpt-3.5-turbo" + vision: "gpt-4o" + tools: "gpt-3.5-turbo" + speech: "tts-1" + transcription: "whisper-1" + embeddings: "text-embedding-3-small" + alternatives: + - "gpt-4" + - "gpt-4-turbo-preview" + - "gpt-4o" + - "gpt-4o-mini" + speech_alternatives: + - "tts-1-hd" + transcription_alternatives: + - "whisper-1" + embeddings_alternatives: + - "text-embedding-3-large" + - "text-embedding-ada-002" + + anthropic: + chat: "claude-3-haiku-20240307" + vision: "claude-3-haiku-20240307" + tools: "claude-3-haiku-20240307" + speech: null # Anthropic doesn't support speech synthesis + transcription: null # Anthropic doesn't support transcription + alternatives: + - "claude-3-sonnet-20240229" + - "claude-3-opus-20240229" + - "claude-3-5-sonnet-20241022" + + google: + chat: "gemini-2.0-flash-001" + vision: "gemini-2.0-flash-001" + tools: "gemini-2.0-flash-001" + speech: null # Google doesn't expose speech synthesis through Gemini API + transcription: null # Google doesn't expose transcription through Gemini API + alternatives: + - "gemini-1.5-pro" + - "gemini-1.5-flash" + - "gemini-1.0-pro" + + litellm: + chat: "gpt-3.5-turbo" # Uses OpenAI by default + vision: "gpt-4o" # Uses OpenAI vision + tools: "gpt-3.5-turbo" # Uses OpenAI for tools + speech: "tts-1" # Uses OpenAI TTS through LiteLLM + transcription: "whisper-1" # Uses OpenAI Whisper through LiteLLM + embeddings: "text-embedding-3-small" # Uses OpenAI embeddings through LiteLLM + alternatives: + - "claude-3-haiku-20240307" # Anthropic via LiteLLM + - "gemini-2.0-flash-001" # Google via LiteLLM + - "gpt-4" # OpenAI GPT-4 + - "mistral-7b-instruct" # Mistral via LiteLLM + - "command-r-plus" # Cohere via LiteLLM + + langchain: + chat: "gpt-3.5-turbo" # OpenAI models via LangChain + vision: "gpt-4o" # OpenAI vision via LangChain + tools: "gpt-3.5-turbo" # Function calling via LangChain + speech: "tts-1" # OpenAI TTS via LangChain + transcription: "whisper-1" # OpenAI Whisper via LangChain + embeddings: "text-embedding-3-small" # OpenAI embeddings via LangChain + alternatives: + - "claude-3-haiku-20240307" # Anthropic via LangChain + - "gemini-2.0-flash-001" # Google via LangChain + - "gpt-4" # OpenAI GPT-4 via LangChain + +# Model capabilities matrix +model_capabilities: + # OpenAI Models + "gpt-3.5-turbo": + chat: true + tools: true + vision: false + streaming: true + max_tokens: 4096 + context_window: 4096 + + "gpt-4": + chat: true + tools: true + vision: false + streaming: true + max_tokens: 8192 + context_window: 8192 + + "gpt-4o": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 4096 + context_window: 128000 + + "gpt-4o-mini": + chat: true + tools: true + vision: true + streaming: true + speech: false + transcription: false + max_tokens: 4096 + context_window: 128000 + + # OpenAI Speech Models + "tts-1": + chat: false + tools: false + vision: false + streaming: false + speech: true + transcription: false + max_tokens: null + context_window: null + + "tts-1-hd": + chat: false + tools: false + vision: false + streaming: false + speech: true + transcription: false + max_tokens: null + context_window: null + + # OpenAI Transcription Models + "whisper-1": + chat: false + tools: false + vision: false + streaming: false + speech: false + transcription: true + embeddings: false + max_tokens: null + context_window: null + + # OpenAI Embedding Models + "text-embedding-3-small": + chat: false + tools: false + vision: false + streaming: false + speech: false + transcription: false + embeddings: true + max_tokens: null + context_window: 8191 + dimensions: 1536 + + "text-embedding-3-large": + chat: false + tools: false + vision: false + streaming: false + speech: false + transcription: false + embeddings: true + max_tokens: null + context_window: 8191 + dimensions: 3072 + + "text-embedding-ada-002": + chat: false + tools: false + vision: false + streaming: false + speech: false + transcription: false + embeddings: true + max_tokens: null + context_window: 8191 + dimensions: 1536 + + # Anthropic Models + "claude-3-haiku-20240307": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 4096 + context_window: 200000 + + "claude-3-sonnet-20240229": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 4096 + context_window: 200000 + + "claude-3-opus-20240229": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 4096 + context_window: 200000 + + # Google Models + "gemini-pro": + chat: true + tools: true + vision: false + streaming: true + max_tokens: 8192 + context_window: 32768 + + "gemini-2.0-flash-001": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 8192 + context_window: 32768 + + "gemini-1.5-pro": + chat: true + tools: true + vision: true + streaming: true + max_tokens: 8192 + context_window: 1000000 + + # Mistral Models + "mistral-7b-instruct": + chat: true + tools: false + vision: false + streaming: true + max_tokens: 4096 + context_window: 32768 + + "mistral-8x7b-instruct": + chat: true + tools: true + vision: false + streaming: true + max_tokens: 4096 + context_window: 32768 + +# Test configuration +test_settings: + # Maximum tokens for test responses + max_tokens: + chat: 100 + vision: 200 + tools: 100 + complex: 300 + speech: null # Speech doesn't use token limits + transcription: null # Transcription doesn't use token limits + embeddings: null # Embeddings don't use token limits (text is the input) + + # Timeout settings for tests + timeouts: + simple: 30 # seconds + complex: 60 # seconds + + # Retry settings for flaky tests + retries: + max_attempts: 3 + delay: 2 # seconds + +# Integration-specific settings +integration_settings: + openai: + organization: "${OPENAI_ORG_ID:-}" + project: "${OPENAI_PROJECT_ID:-}" + + anthropic: + version: "2023-06-01" + + google: + project_id: "${GOOGLE_PROJECT_ID:-}" + location: "${GOOGLE_LOCATION:-us-central1}" + + litellm: + drop_params: true + debug: false + + langchain: + debug: false + streaming: true + +# Environment-specific overrides +environments: + development: + api: + timeout: 60 + max_retries: 5 + test_settings: + timeouts: + simple: 60 + complex: 120 + + production: + api: + timeout: 15 + max_retries: 2 + test_settings: + timeouts: + simple: 20 + complex: 40 + +# Logging configuration +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "tests.log" diff --git a/tests/integrations/pytest.ini b/tests/integrations/pytest.ini new file mode 100644 index 000000000..6c53a50ea --- /dev/null +++ b/tests/integrations/pytest.ini @@ -0,0 +1,27 @@ +[pytest] +# Test discovery +testpaths = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output formatting +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes + +# Timeout settings (3 minutes per test) +timeout = 180 + +# Markers for test categorization +markers = + integration: marks tests as integration tests + slow: marks tests as slow running + e2e: marks tests as end-to-end tests + tool_calling: marks tests as tool calling tests + +# Minimum version +minversion = 7.0 \ No newline at end of file diff --git a/tests/integrations/requirements.txt b/tests/integrations/requirements.txt new file mode 100644 index 000000000..32adb34b5 --- /dev/null +++ b/tests/integrations/requirements.txt @@ -0,0 +1,43 @@ +# Core testing framework +pytest>=7.0.0 +pytest-asyncio>=0.21.0 + +# Environment and configuration +python-dotenv>=1.0.0 +PyYAML>=6.0 + +# Image processing +Pillow>=9.0.0 + +# HTTP requests for debugging +requests>=2.28.0 + +# Type hints +typing-extensions>=4.0.0 + +# Optional: For better test reporting +pytest-html>=3.1.0 +pytest-cov>=4.0.0 + +# AI/ML SDK dependencies +openai>=1.30.0 +anthropic>=0.25.0 +litellm>=1.35.0 +langchain-openai>=0.1.0 +langchain-core>=0.2.0 +langchain-anthropic>=0.1.0 +langchain-google-genai>=1.0.0 +langchain-mistralai>=0.1.0 +langgraph>=0.1.0 +mistralai>=0.4.0 +google-genai>=1.0.0 + +# Optional testing utilities +httpx>=0.25.0 +pytest-timeout>=2.1.0 +pytest-mock>=3.11.0 + +# Development dependencies (optional) +black>=23.0.0 # Code formatting +flake8>=6.0.0 # Linting +mypy>=1.5.0 # Type checking \ No newline at end of file diff --git a/tests/integrations/run_all_tests.py b/tests/integrations/run_all_tests.py new file mode 100755 index 000000000..953fff318 --- /dev/null +++ b/tests/integrations/run_all_tests.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +Bifrost Integration End-to-End Test Runner + +This script runs all integration end-to-end tests for Bifrost. +It can run tests individually or all together, providing comprehensive +reporting and flexible execution options. + +Usage: + python run_all_tests.py # Run all tests + python run_all_tests.py --integration openai # Run specific integration + python run_all_tests.py --list # List available integrations + python run_all_tests.py --parallel # Run tests in parallel + python run_all_tests.py --verbose # Verbose output +""" + +import argparse +import subprocess +import sys +import time +import os +from pathlib import Path +from typing import List, Dict, Optional +import concurrent.futures +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + + +class BifrostTestRunner: + """Main test runner for Bifrost integration tests""" + + def __init__(self): + self.test_dir = Path(__file__).parent + self.integrations = { + "openai": { + "file": "tests/integrations/test_openai.py", + "description": "OpenAI Python SDK integration tests", + "env_vars": ["OPENAI_API_KEY"], + }, + "anthropic": { + "file": "tests/integrations/test_anthropic.py", + "description": "Anthropic Python SDK integration tests", + "env_vars": ["ANTHROPIC_API_KEY"], + }, + "litellm": { + "file": "tests/integrations/test_litellm.py", + "description": "LiteLLM integration tests", + "env_vars": ["OPENAI_API_KEY"], # LiteLLM can use OpenAI key + }, + "langchain": { + "file": "tests/integrations/test_langchain.py", + "description": "LangChain integration tests", + "env_vars": [ + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + ], # LangChain uses multiple providers + }, + "google": { + "file": "tests/integrations/test_google.py", + "description": "Google GenAI integration tests", + "env_vars": ["GOOGLE_API_KEY"], + }, + } + + self.results = {} + + def check_environment(self, integration: str) -> bool: + """Check if required environment variables are set for an integration""" + config = self.integrations[integration] + missing_vars = [] + + for var in config["env_vars"]: + if not os.getenv(var): + missing_vars.append(var) + + if missing_vars: + print( + f"⚠ Skipping {integration}: Missing environment variables: {', '.join(missing_vars)}" + ) + return False + + return True + + def run_integration_test(self, integration: str, verbose: bool = False) -> Dict: + """Run tests for a specific integration""" + if integration not in self.integrations: + return {"success": False, "error": f"Unknown integration: {integration}"} + + config = self.integrations[integration] + test_file = self.test_dir / config["file"] + + if not test_file.exists(): + return {"success": False, "error": f"Test file not found: {test_file}"} + + # Check environment variables + if not self.check_environment(integration): + return { + "success": False, + "error": "Missing required environment variables", + "skipped": True, + } + + print(f"\n{'='*60}") + print(f"Running {integration.upper()} Integration Tests") + print(f"{'='*60}") + print(f"Description: {config['description']}") + print(f"Test file: {config['file']}") + + start_time = time.time() + + try: + # Run the test with pytest + cmd = [sys.executable, "-m", "pytest", str(test_file)] + + # Add pytest flags for better output + if verbose: + cmd.extend(["-v", "-s"]) # verbose and don't capture output + else: + cmd.append("-q") # quiet mode + + if verbose: + result = subprocess.run( + cmd, cwd=self.test_dir, text=True, capture_output=False, timeout=300 + ) + else: + result = subprocess.run( + cmd, cwd=self.test_dir, text=True, capture_output=True, timeout=300 + ) + + elapsed_time = time.time() - start_time + + success = result.returncode == 0 + + return { + "success": success, + "return_code": result.returncode, + "stdout": result.stdout if not verbose else "", + "stderr": result.stderr if not verbose else "", + "elapsed_time": elapsed_time, + } + + except subprocess.TimeoutExpired: + return { + "success": False, + "error": "Test timed out (5 minutes)", + "elapsed_time": 300, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "elapsed_time": time.time() - start_time, + } + + def run_all_tests(self, parallel: bool = False, verbose: bool = False) -> None: + """Run all integration tests""" + print("Bifrost Integration End-to-End Test Suite") + print("=" * 50) + print(f"Running tests for {len(self.integrations)} integrations") + print(f"Parallel execution: {'Enabled' if parallel else 'Disabled'}") + print(f"Verbose output: {'Enabled' if verbose else 'Disabled'}") + + # Check Bifrost availability + bifrost_url = os.getenv("BIFROST_BASE_URL", "http://localhost:8080") + print(f"Bifrost URL: {bifrost_url}") + + start_time = time.time() + + if parallel: + self._run_parallel(verbose) + else: + self._run_sequential(verbose) + + total_time = time.time() - start_time + self._print_summary(total_time) + + def _run_sequential(self, verbose: bool) -> None: + """Run tests sequentially""" + for integration in self.integrations: + self.results[integration] = self.run_integration_test(integration, verbose) + + def _run_parallel(self, verbose: bool) -> None: + """Run tests in parallel""" + print("\nRunning tests in parallel...") + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + # Submit all tests + future_to_integration = { + executor.submit( + self.run_integration_test, integration, verbose + ): integration + for integration in self.integrations + } + + # Collect results + for future in concurrent.futures.as_completed(future_to_integration): + integration = future_to_integration[future] + try: + self.results[integration] = future.result() + except Exception as e: + self.results[integration] = {"success": False, "error": str(e)} + + def _print_summary(self, total_time: float) -> None: + """Print test summary""" + print(f"\n{'='*60}") + print("TEST SUMMARY") + print(f"{'='*60}") + + passed = 0 + failed = 0 + skipped = 0 + + for integration, result in self.results.items(): + status = ( + "SKIPPED" + if result.get("skipped") + else ("PASSED" if result["success"] else "FAILED") + ) + elapsed = result.get("elapsed_time", 0) + + if result.get("skipped"): + skipped += 1 + print( + f"⚠ {integration:12} {status:8} - {result.get('error', 'Unknown error')}" + ) + elif result["success"]: + passed += 1 + print(f"βœ“ {integration:12} {status:8} - {elapsed:.2f}s") + else: + failed += 1 + error_msg = result.get("error", "Unknown error") + print(f"βœ— {integration:12} {status:8} - {error_msg}") + + # Print stderr if available + if "stderr" in result and result["stderr"]: + print(f" Error output: {result['stderr'][:200]}...") + + print(f"\n{'='*60}") + print( + f"Total: {len(self.integrations)} | Passed: {passed} | Failed: {failed} | Skipped: {skipped}" + ) + print(f"Total time: {total_time:.2f} seconds") + print(f"{'='*60}") + + # Exit with appropriate code + if failed > 0: + sys.exit(1) + else: + print("All tests completed successfully!") + + def list_integrations(self) -> None: + """List available integrations""" + print("Available Integrations:") + print("=" * 30) + + for integration, config in self.integrations.items(): + env_status = "βœ“" if self.check_environment(integration) else "βœ—" + print(f"{env_status} {integration:12} - {config['description']}") + print(f" Required env vars: {', '.join(config['env_vars'])}") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Run Bifrost integration end-to-end tests", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run_all_tests.py # Run all tests + python run_all_tests.py --integration openai # Run OpenAI tests only + python run_all_tests.py --parallel --verbose # Run all tests in parallel with verbose output + python run_all_tests.py --list # List available integrations + """, + ) + + parser.add_argument( + "--integration", "-i", help="Run tests for specific integration only" + ) + + parser.add_argument( + "--list", + "-l", + action="store_true", + help="List available integrations and their status", + ) + + parser.add_argument( + "--parallel", + "-p", + action="store_true", + help="Run tests in parallel (faster but less readable output)", + ) + + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose output (shows test output in real-time)", + ) + + args = parser.parse_args() + + runner = BifrostTestRunner() + + if args.list: + runner.list_integrations() + return + + if args.integration: + if args.integration not in runner.integrations: + print(f"Error: Unknown integration '{args.integration}'") + print(f"Available integrations: {', '.join(runner.integrations.keys())}") + sys.exit(1) + + result = runner.run_integration_test(args.integration, args.verbose) + if result["success"]: + print(f"\nβœ“ {args.integration} tests passed!") + else: + error_msg = result.get("error", "Unknown error") + print(f"\nβœ— {args.integration} tests failed: {error_msg}") + + # Show stdout/stderr if available + if result.get("stdout"): + print("\n--- Test Output ---") + print(result["stdout"]) + if result.get("stderr"): + print("\n--- Error Output ---") + print(result["stderr"]) + + sys.exit(1) + else: + runner.run_all_tests(args.parallel, args.verbose) + + +if __name__ == "__main__": + main() diff --git a/tests/integrations/run_integration_tests.py b/tests/integrations/run_integration_tests.py new file mode 100755 index 000000000..169e7f0f2 --- /dev/null +++ b/tests/integrations/run_integration_tests.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +""" +Integration-specific test runner for Bifrost integration tests. + +This script runs tests for each integration independently using their native SDKs. +No more complex gateway conversions - just direct testing! +""" + +import os +import sys +import argparse +import subprocess +from pathlib import Path +from typing import List, Optional + + +def check_api_keys(): + """Check which API keys are available""" + keys = { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + "google": os.getenv("GOOGLE_API_KEY"), + "litellm": os.getenv("LITELLM_API_KEY"), + } + + available = [integration for integration, key in keys.items() if key] + missing = [integration for integration, key in keys.items() if not key] + + return available, missing + + +def run_integration_tests( + integrations: List[str], test_pattern: Optional[str] = None, verbose: bool = False +): + """Run tests for specified integrations""" + + results = {} + + for integration in integrations: + print(f"\n{'='*60}") + print(f"πŸ§ͺ TESTING {integration.upper()} INTEGRATION") + print(f"{'='*60}") + + # Build pytest command with absolute path relative to script location + script_dir = Path(__file__).parent + test_file = script_dir / "tests" / "integrations" / f"test_{integration}.py" + + # Check if test file exists + if not test_file.exists(): + print(f"❌ Test file not found: {test_file}") + results[integration] = {"error": f"Test file not found: {test_file}"} + continue + + cmd = ["python", "-m", "pytest", str(test_file)] + + if test_pattern: + cmd.extend(["-k", test_pattern]) + + if verbose: + cmd.append("-v") + else: + cmd.append("-q") + + # Remove integration-specific marker (not needed for file-based selection) + # cmd.extend(["-m", integration]) + + # Run the tests + try: + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + check=True, + ) + results[integration] = { + "returncode": result.returncode, + "stdout": result.stdout, + "stderr": "", # stderr is now captured in stdout + } + + # Print results + print(f"βœ… {integration.upper()} tests PASSED") + + if verbose: + print(result.stdout) + + except subprocess.CalledProcessError as e: + print(f"❌ {integration.upper()} tests FAILED") + results[integration] = { + "returncode": e.returncode, + "stdout": e.stdout, + "stderr": "", # stderr is captured in stdout + } + + # Always print output on failure to show what went wrong + if e.stdout: + print(e.stdout) + + except Exception as e: + print(f"❌ Error running {integration} tests: {e}") + results[integration] = {"error": str(e)} + + return results + + +def print_summary( + results: dict, available_integrations: List[str], missing_integrations: List[str] +): + """Print final summary""" + print(f"\n{'='*80}") + print("🎯 FINAL SUMMARY") + print(f"{'='*80}") + + # API Key Status + print(f"\nπŸ”‘ API Key Status:") + for integration in available_integrations: + print(f" βœ… {integration.upper()}: Available") + + for integration in missing_integrations: + print(f" ❌ {integration.upper()}: Missing API key") + + # Test Results + print(f"\nπŸ“Š Test Results:") + passed_integrations = [] + failed_integrations = [] + + for integration, result in results.items(): + if "error" in result: + print(f" πŸ’₯ {integration.upper()}: Error - {result['error']}") + failed_integrations.append(integration) + elif result["returncode"] == 0: + print(f" βœ… {integration.upper()}: All tests passed") + passed_integrations.append(integration) + else: + print(f" ❌ {integration.upper()}: Some tests failed") + failed_integrations.append(integration) + + # Overall Status + total_tested = len(results) + total_passed = len(passed_integrations) + + print(f"\nπŸ† Overall Results:") + print(f" Integrations tested: {total_tested}") + print(f" Integrations passed: {total_passed}") + print( + f" Success rate: {(total_passed/total_tested)*100:.1f}%" + if total_tested > 0 + else " Success rate: N/A" + ) + + if failed_integrations: + print(f"\n⚠️ Failed integrations: {', '.join(failed_integrations)}") + print(" Check the detailed output above for specific test failures.") + + +def main(): + parser = argparse.ArgumentParser( + description="Run integration-specific integration tests" + ) + parser.add_argument( + "--integrations", + nargs="+", + choices=["openai", "anthropic", "google", "litellm", "all"], + default=["all"], + help="Integrations to test (default: all available)", + ) + parser.add_argument( + "--test", help="Run specific test pattern (e.g., 'test_01_simple_chat')" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--check-keys", action="store_true", help="Only check API key availability" + ) + parser.add_argument( + "--show-models", + action="store_true", + help="Show model configuration for all integrations", + ) + + args = parser.parse_args() + + # Check API keys + available_integrations, missing_integrations = check_api_keys() + + if args.check_keys: + print("πŸ”‘ API Key Status:") + for integration in available_integrations: + print(f" βœ… {integration.upper()}: Available") + for integration in missing_integrations: + print(f" ❌ {integration.upper()}: Missing") + return + + if args.show_models: + # Import and show model configuration using absolute path + script_dir = Path(__file__).parent + models_path = script_dir / "tests" / "utils" / "models.py" + + if not models_path.exists(): + print(f"❌ Models file not found: {models_path}") + sys.exit(1) + + # Add the parent directory to sys.path to enable the import + models_parent_dir = str(script_dir) + if models_parent_dir not in sys.path: + sys.path.insert(0, models_parent_dir) + + try: + from tests.utils.models import print_model_summary + + print_model_summary() + except ImportError as e: + print(f"❌ Could not import print_model_summary: {e}") + print(f"Tried to import from: {models_path}") + sys.exit(1) + return + + # Determine which integrations to test + if "all" in args.integrations: + integrations_to_test = available_integrations + requested_integrations = [ + "openai", + "anthropic", + "google", + "litellm", + ] # all possible integrations + else: + integrations_to_test = [ + p for p in args.integrations if p in available_integrations + ] + requested_integrations = args.integrations + + if not integrations_to_test: + print("❌ No integrations available for testing. Please set API keys.") + print("\nRequired environment variables for requested integrations:") + for integration in requested_integrations: + if integration != "all": # Skip the "all" keyword + api_key_name = f"{integration.upper()}_API_KEY" + print(f" - {api_key_name}") + sys.exit(1) + + # Calculate which requested integrations are missing API keys + requested_missing_integrations = [ + integration + for integration in requested_integrations + if integration in missing_integrations + ] + + # Show what we're about to test + print("πŸš€ Starting integration tests...") + print(f"πŸ“‹ Testing integrations: {', '.join(integrations_to_test)}") + if requested_missing_integrations: + print( + f"⏭️ Skipping integrations (no API key): {', '.join(requested_missing_integrations)}" + ) + + # Run tests + results = run_integration_tests(integrations_to_test, args.test, args.verbose) + + # Print summary + print_summary(results, available_integrations, requested_missing_integrations) + + # Exit with appropriate code + failed_count = sum( + 1 for r in results.values() if r.get("returncode", 1) != 0 or "error" in r + ) + sys.exit(failed_count) + + +if __name__ == "__main__": + main() diff --git a/tests/integrations/test_audio.py b/tests/integrations/test_audio.py new file mode 100755 index 000000000..e52299897 --- /dev/null +++ b/tests/integrations/test_audio.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +Dedicated test runner for Speech and Transcription functionality. +This script runs only the speech and transcription tests for easier development and debugging. + +Usage: + python test_audio.py + python test_audio.py --verbose + python test_audio.py --help +""" + +import sys +import os +import argparse +import subprocess +from pathlib import Path + +# Add the tests directory to Python path +tests_dir = Path(__file__).parent +sys.path.insert(0, str(tests_dir)) + + +def run_speech_transcription_tests(verbose=False, specific_test=None): + """Run speech and transcription tests""" + + # Change to the tests directory + os.chdir(tests_dir) + + # Build pytest command + cmd = ["python", "-m", "pytest"] + + if verbose: + cmd.append("-v") + else: + cmd.append("-q") + + # Add specific test pattern for speech/transcription tests + if specific_test: + test_pattern = f"tests/integrations/test_openai.py::{specific_test}" + else: + # Run all speech and transcription related tests + test_pattern = "tests/integrations/test_openai.py::TestOpenAIIntegration::test_14_speech_synthesis" + cmd.extend( + [ + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_14_speech_synthesis", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_15_transcription_audio", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_16_transcription_streaming", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_17_speech_transcription_round_trip", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_18_speech_error_handling", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_19_transcription_error_handling", + "tests/integrations/test_openai.py::TestOpenAIIntegration::test_20_speech_different_voices_and_formats", + ] + ) + + if not specific_test: + # Add some useful pytest options + cmd.extend( + [ + "--tb=short", # Shorter traceback format + "--maxfail=3", # Stop after 3 failures + "-x", # Stop on first failure + ] + ) + else: + cmd.append(test_pattern) + + # Add environment info + print("🎡 SPEECH & TRANSCRIPTION INTEGRATION TESTS") + print("=" * 60) + print(f"πŸ”§ Running from: {tests_dir}") + print(f"πŸ“‹ Environment variables needed:") + print(" - OPENAI_API_KEY (required)") + print(" - BIFROST_BASE_URL (optional, defaults to http://localhost:8080)") + print() + + # Check for required environment variables + if not os.getenv("OPENAI_API_KEY"): + print("❌ ERROR: OPENAI_API_KEY environment variable is required") + print(" Set it with: export OPENAI_API_KEY=your_key_here") + return 1 + + bifrost_url = os.getenv("BIFROST_BASE_URL", "http://localhost:8080") + print(f"πŸŒ‰ Bifrost URL: {bifrost_url}") + print(f"πŸ€– Testing OpenAI integration through Bifrost proxy") + print() + + # Run the tests + print("πŸš€ Starting Speech & Transcription Tests...") + print("-" * 60) + + try: + result = subprocess.run(cmd, cwd=tests_dir) + return result.returncode + except KeyboardInterrupt: + print("\n❌ Tests interrupted by user") + return 1 + except Exception as e: + print(f"\n❌ Error running tests: {e}") + return 1 + + +def list_available_tests(): + """List all available speech and transcription tests""" + tests = [ + "test_14_speech_synthesis", + "test_15_transcription_audio", + "test_16_transcription_streaming", + "test_17_speech_transcription_round_trip", + "test_18_speech_error_handling", + "test_19_transcription_error_handling", + "test_20_speech_different_voices_and_formats", + ] + + print("🎡 Available Speech & Transcription Tests:") + print("=" * 50) + for i, test in enumerate(tests, 1): + print(f"{i:2d}. {test}") + print() + print("Run specific test with: python test_audio.py --test ") + + +def main(): + parser = argparse.ArgumentParser( + description="Run Speech and Transcription integration tests", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python test_audio.py # Run all speech/transcription tests + python test_audio.py --verbose # Run with verbose output + python test_audio.py --list # List available tests + python test_audio.py --test test_14_speech_synthesis # Run specific test + """, + ) + + parser.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose output" + ) + + parser.add_argument("--test", "-t", type=str, help="Run a specific test by name") + + parser.add_argument( + "--list", "-l", action="store_true", help="List available tests" + ) + + args = parser.parse_args() + + if args.list: + list_available_tests() + return 0 + + return run_speech_transcription_tests(verbose=args.verbose, specific_test=args.test) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/integrations/tests/__init__.py b/tests/integrations/tests/__init__.py new file mode 100644 index 000000000..92e4c036e --- /dev/null +++ b/tests/integrations/tests/__init__.py @@ -0,0 +1,8 @@ +""" +Bifrost Integration Tests + +Production-ready test suite for testing various AI integrations through Bifrost proxy. +Supports multiple integrations with uniform test interface. +""" + +__version__ = "1.0.0" diff --git a/tests/integrations/tests/conftest.py b/tests/integrations/tests/conftest.py new file mode 100644 index 000000000..bf8dc16a0 --- /dev/null +++ b/tests/integrations/tests/conftest.py @@ -0,0 +1,162 @@ +""" +Pytest configuration for integration-specific tests. +""" + +import pytest +import os + + +def pytest_configure(config): + """Configure pytest with custom markers""" + config.addinivalue_line("markers", "openai: mark test as requiring OpenAI API key") + config.addinivalue_line( + "markers", "anthropic: mark test as requiring Anthropic API key" + ) + config.addinivalue_line("markers", "google: mark test as requiring Google API key") + config.addinivalue_line("markers", "litellm: mark test as requiring LiteLLM setup") + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to add markers based on test file names""" + for item in items: + # Add markers based on test file location + if "test_openai" in item.nodeid: + item.add_marker(pytest.mark.openai) + elif "test_anthropic" in item.nodeid: + item.add_marker(pytest.mark.anthropic) + elif "test_google" in item.nodeid: + item.add_marker(pytest.mark.google) + elif "test_litellm" in item.nodeid: + item.add_marker(pytest.mark.litellm) + + +@pytest.fixture(scope="session") +def api_keys(): + """Collect all available API keys""" + return { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + "google": os.getenv("GOOGLE_API_KEY"), + "litellm": os.getenv("LITELLM_API_KEY"), + } + + +@pytest.fixture(scope="session") +def available_integrations(api_keys): + """Determine which integrations are available based on API keys""" + available = [] + + if api_keys["openai"]: + available.append("openai") + if api_keys["anthropic"]: + available.append("anthropic") + if api_keys["google"]: + available.append("google") + if api_keys["litellm"]: + available.append("litellm") + + return available + + +@pytest.fixture +def test_summary(): + """Fixture to collect test results for summary reporting""" + results = {"passed": [], "failed": [], "skipped": []} + return results + + +def pytest_runtest_makereport(item, call): + """Hook to capture test results""" + # Only record results during the "call" phase to avoid double counting + if call.when == "call": + # Extract integration and test info + integration = None + if "test_openai" in item.nodeid: + integration = "openai" + elif "test_anthropic" in item.nodeid: + integration = "anthropic" + elif "test_google" in item.nodeid: + integration = "google" + elif "test_litellm" in item.nodeid: + integration = "litellm" + + test_name = item.name + + # Store result info + result_info = { + "integration": integration, + "test": test_name, + "nodeid": item.nodeid, + } + + if hasattr(item.session, "test_results"): + if call.excinfo is None: + item.session.test_results["passed"].append(result_info) + else: + result_info["error"] = str(call.excinfo.value) + item.session.test_results["failed"].append(result_info) + + +def pytest_sessionstart(session): + """Initialize test results collection""" + session.test_results = {"passed": [], "failed": [], "skipped": []} + + +def pytest_sessionfinish(session, exitstatus): + """Print test summary at the end""" + results = session.test_results + + print("\n" + "=" * 80) + print("INTEGRATION TEST SUMMARY") + print("=" * 80) + + # Group results by integration + integration_results = {} + + for result in results["passed"] + results["failed"] + results["skipped"]: + integration = result.get("integration", "unknown") + if integration and integration not in integration_results: + integration_results[integration] = {"passed": 0, "failed": 0, "skipped": 0} + + for result in results["passed"]: + integration = result.get("integration", "unknown") + if integration and integration in integration_results: + integration_results[integration]["passed"] += 1 + + for result in results["failed"]: + integration = result.get("integration", "unknown") + if integration and integration in integration_results: + integration_results[integration]["failed"] += 1 + + for result in results["skipped"]: + integration = result.get("integration", "unknown") + if integration and integration in integration_results: + integration_results[integration]["skipped"] += 1 + + # Print summary by integration + for integration, counts in integration_results.items(): + total = counts["passed"] + counts["failed"] + counts["skipped"] + if total > 0: + print(f"\n{integration.upper()} Integration:") + print(f" βœ… Passed: {counts['passed']}") + print(f" ❌ Failed: {counts['failed']}") + print(f" ⏭️ Skipped: {counts['skipped']}") + print(f" πŸ“Š Total: {total}") + + if counts["passed"] > 0: + success_rate = ( + (counts["passed"] / (counts["passed"] + counts["failed"])) * 100 + if (counts["passed"] + counts["failed"]) > 0 + else 0 + ) + print(f" 🎯 Success Rate: {success_rate:.1f}%") + + # Print failed tests details + if results["failed"]: + print(f"\n❌ FAILED TESTS ({len(results['failed'])}):") + for result in results["failed"]: + print(f" β€’ {result['integration']}: {result['test']}") + if "error" in result: + print(f" Error: {result['error']}") + + print("\n" + "=" * 80) diff --git a/tests/integrations/tests/integrations/__init__.py b/tests/integrations/tests/integrations/__init__.py new file mode 100644 index 000000000..ec4135e3b --- /dev/null +++ b/tests/integrations/tests/integrations/__init__.py @@ -0,0 +1 @@ +# Integration-specific test packages diff --git a/tests/integrations/tests/integrations/test_anthropic.py b/tests/integrations/tests/integrations/test_anthropic.py new file mode 100644 index 000000000..83f351b99 --- /dev/null +++ b/tests/integrations/tests/integrations/test_anthropic.py @@ -0,0 +1,628 @@ +""" +Anthropic Integration Tests + +πŸ€– MODELS USED: +- Chat: claude-3-haiku-20240307 +- Vision: claude-3-haiku-20240307 +- Tools: claude-3-haiku-20240307 +- Alternatives: claude-3-sonnet-20240229, claude-3-opus-20240229, claude-3-5-sonnet-20241022 + +Tests all 11 core scenarios using Anthropic SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import base64 +import requests +from anthropic import Anthropic +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL, + BASE64_IMAGE, + INVALID_ROLE_MESSAGES, + STREAMING_CHAT_MESSAGES, + STREAMING_TOOL_CALL_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + ALL_TOOLS, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, + assert_valid_streaming_response, + collect_streaming_content, + extract_tool_calls, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def anthropic_client(): + """Create Anthropic client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("anthropic") + base_url = get_integration_url("anthropic") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("anthropic") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add Anthropic-specific settings + if integration_settings.get("version"): + client_kwargs["default_headers"] = { + "anthropic-version": integration_settings["version"] + } + + return Anthropic(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +def convert_to_anthropic_messages( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Convert common message format to Anthropic format""" + anthropic_messages = [] + + for msg in messages: + if msg["role"] == "system": + continue # System messages handled separately in Anthropic + + # Handle image messages + if isinstance(msg.get("content"), list): + content = [] + for item in msg["content"]: + if item["type"] == "text": + content.append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + url = item["image_url"]["url"] + if url.startswith("data:image"): + # Base64 image + media_type, data = url.split(",", 1) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": data, + }, + } + ) + else: + # URL image - need to download and convert to base64 + response = requests.get(url) + img_data = base64.b64encode(response.content).decode() + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + } + ) + + anthropic_messages.append({"role": msg["role"], "content": content}) + else: + anthropic_messages.append({"role": msg["role"], "content": msg["content"]}) + + return anthropic_messages + + +def convert_to_anthropic_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to Anthropic format""" + anthropic_tools = [] + + for tool in tools: + anthropic_tools.append( + { + "name": tool["name"], + "description": tool["description"], + "input_schema": tool["parameters"], + } + ) + + return anthropic_tools + + +class TestAnthropicIntegration: + """Test suite for Anthropic integration covering all 11 core scenarios""" + + @skip_if_no_api_key("anthropic") + def test_01_simple_chat(self, anthropic_client, test_config): + """Test Case 1: Simple chat interaction""" + messages = convert_to_anthropic_messages(SIMPLE_CHAT_MESSAGES) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=100 + ) + + assert_valid_chat_response(response) + assert len(response.content) > 0 + assert response.content[0].type == "text" + assert len(response.content[0].text) > 0 + + @skip_if_no_api_key("anthropic") + def test_02_multi_turn_conversation(self, anthropic_client, test_config): + """Test Case 2: Multi-turn conversation""" + messages = convert_to_anthropic_messages(MULTI_TURN_MESSAGES) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(response) + content = response.content[0].text.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("anthropic") + def test_03_single_tool_call(self, anthropic_client, test_config): + """Test Case 3: Single tool call""" + messages = convert_to_anthropic_messages(SINGLE_TOOL_CALL_MESSAGES) + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + @skip_if_no_api_key("anthropic") + def test_04_multiple_tool_calls(self, anthropic_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + messages = convert_to_anthropic_messages(MULTIPLE_TOOL_CALL_MESSAGES) + tools = convert_to_anthropic_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=200, + ) + + # Anthropic might be more conservative with multiple tool calls + # Let's check if it made at least one tool call and prefer multiple if possible + assert_has_tool_calls(response) # At least 1 tool call + tool_calls = extract_anthropic_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + + # Should make relevant tool calls - either weather, calculate, or both + expected_tools = ["get_weather", "calculate"] + made_relevant_calls = any(name in expected_tools for name in tool_names) + assert ( + made_relevant_calls + ), f"Expected tool calls from {expected_tools}, got {tool_names}" + + @skip_if_no_api_key("anthropic") + def test_05_end2end_tool_calling(self, anthropic_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's response to conversation + messages.append({"role": "assistant", "content": response.content}) + + # Add tool response + tool_calls = extract_anthropic_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + # Find the tool use block to get its ID + tool_use_id = None + for content in response.content: + if content.type == "tool_use": + tool_use_id = content.id + break + + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_response, + } + ], + } + ) + + # Get final response + final_response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=150 + ) + + # Anthropic might return empty content if tool result is sufficient + assert final_response is not None + if len(final_response.content) > 0: + assert_valid_chat_response(final_response) + content = final_response.content[0].text.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + else: + # If no content, that's ok - tool result was sufficient + print("Model returned empty content - tool result was sufficient") + + @skip_if_no_api_key("anthropic") + def test_06_automatic_function_calling(self, anthropic_client, test_config): + """Test Case 6: Automatic function calling""" + messages = [{"role": "user", "content": "Calculate 25 * 4 for me"}] + tools = convert_to_anthropic_tools([CALCULATOR_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + @skip_if_no_api_key("anthropic") + def test_07_image_url(self, anthropic_client, test_config): + """Test Case 7: Image analysis from URL""" + # Download image and convert to base64 for Anthropic + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("anthropic") + def test_08_image_base64(self, anthropic_client, test_config): + """Test Case 8: Image analysis from base64""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": BASE64_IMAGE, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("anthropic") + def test_09_multiple_images(self, anthropic_client, test_config): + """Test Case 9: Multiple image analysis""" + # Download first image + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Compare these two images"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": BASE64_IMAGE, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=300 + ) + + assert_valid_image_response(response) + content = response.content[0].text.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("anthropic") + def test_10_complex_end2end(self, anthropic_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + # Download image for Anthropic format + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + {"role": "user", "content": "Hello! I need help with some tasks."}, + { + "role": "assistant", + "content": "Hello! I'd be happy to help you with your tasks. What do you need assistance with?", + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "First, can you tell me what's in this image and then get the weather for the location shown?", + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + ], + }, + ] + + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response1 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert len(response1.content) > 0 + + # Add response to conversation + messages.append({"role": "assistant", "content": response1.content}) + + # If there were tool calls, handle them + tool_calls = extract_anthropic_tool_calls(response1) + if tool_calls: + for i, tool_call in enumerate(tool_calls): + tool_response = mock_tool_response( + tool_call["name"], tool_call["arguments"] + ) + + # Find the corresponding tool use ID + tool_use_id = None + for content in response1.content: + if content.type == "tool_use" and content.name == tool_call["name"]: + tool_use_id = content.id + break + + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_response, + } + ], + } + ) + + # Get final response after tool calls + final_response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + # Anthropic might return empty content if tool result is sufficient + # This is valid behavior - just check that we got a response + assert final_response is not None + if len(final_response.content) > 0: + # If there is content, validate it + assert_valid_chat_response(final_response) + else: + # If no content, that's ok too - tool result was sufficient + print("Model returned empty content - tool result was sufficient") + + @skip_if_no_api_key("anthropic") + def test_11_integration_specific_features(self, anthropic_client, test_config): + """Test Case 11: Anthropic-specific features""" + + # Test 1: System message + response1 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + system="You are a helpful assistant that always responds in exactly 5 words.", + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=50, + ) + + assert_valid_chat_response(response1) + # Check if response is approximately 5 words (allow some flexibility) + word_count = len(response1.content[0].text.split()) + assert 3 <= word_count <= 7, f"Expected ~5 words, got {word_count}" + + # Test 2: Temperature parameter + response2 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response2) + + # Test 3: Tool choice (any tool) + tools = convert_to_anthropic_tools([CALCULATOR_TOOL, WEATHER_TOOL]) + response3 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=tools, + tool_choice={"type": "any"}, # Force tool use + max_tokens=100, + ) + + assert_has_tool_calls(response3) + tool_calls = extract_anthropic_tool_calls(response3) + # Should prefer calculator for math question + assert tool_calls[0]["name"] == "calculate" + + @skip_if_no_api_key("anthropic") + def test_12_error_handling_invalid_roles(self, anthropic_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "anthropic") + + @skip_if_no_api_key("anthropic") + def test_13_streaming(self, anthropic_client, test_config): + """Test Case 13: Streaming chat completion""" + # Test basic streaming + stream = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=STREAMING_CHAT_MESSAGES, + max_tokens=200, + stream=True, + ) + + content, chunk_count, tool_calls_detected = collect_streaming_content( + stream, "anthropic", timeout=30 + ) + + # Validate streaming results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(content) > 10, "Should receive substantial content" + assert not tool_calls_detected, "Basic streaming shouldn't have tool calls" + + # Test streaming with tool calls + stream_with_tools = anthropic_client.messages.create( + model=get_model("anthropic", "tools"), + messages=STREAMING_TOOL_CALL_MESSAGES, + max_tokens=150, + tools=convert_to_anthropic_tools([WEATHER_TOOL]), + stream=True, + ) + + content_tools, chunk_count_tools, tool_calls_detected_tools = ( + collect_streaming_content(stream_with_tools, "anthropic", timeout=30) + ) + + # Validate tool streaming results + assert chunk_count_tools > 0, "Should receive at least one chunk with tools" + assert tool_calls_detected_tools, "Should receive at least one chunk with tools" + + +# Additional helper functions specific to Anthropic +def extract_anthropic_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from Anthropic response format with proper type checking""" + tool_calls = [] + + # Type check for Anthropic Message response + if not hasattr(response, "content") or not response.content: + return tool_calls + + for content in response.content: + if hasattr(content, "type") and content.type == "tool_use": + if hasattr(content, "name") and hasattr(content, "input"): + try: + tool_calls.append( + {"name": content.name, "arguments": content.input} + ) + except AttributeError as e: + print(f"Warning: Failed to extract tool call from content: {e}") + continue + + return tool_calls diff --git a/tests/integrations/tests/integrations/test_google.py b/tests/integrations/tests/integrations/test_google.py new file mode 100644 index 000000000..fea509222 --- /dev/null +++ b/tests/integrations/tests/integrations/test_google.py @@ -0,0 +1,528 @@ +""" +Google GenAI Integration Tests + +Tests all 11 core scenarios using Google GenAI SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import base64 +import requests +from PIL import Image +import io +from google import genai +from google.genai.types import HttpOptions +from google.genai import types +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL, + BASE64_IMAGE, + INVALID_ROLE_MESSAGES, + STREAMING_CHAT_MESSAGES, + STREAMING_TOOL_CALL_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + assert_valid_chat_response, + assert_valid_embedding_response, + assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, + assert_valid_streaming_response, + collect_streaming_content, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, + GENAI_INVALID_ROLE_CONTENT, + EMBEDDINGS_SINGLE_TEXT, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def google_client(): + """Configure Google GenAI client for testing""" + from ..utils.config_loader import get_integration_url + + api_key = get_api_key("google") + base_url = get_integration_url("google") + + client_kwargs = { + "api_key": api_key, + } + + # Add base URL support and timeout through HttpOptions + http_options_kwargs = {} + if base_url: + http_options_kwargs["base_url"] = base_url + + if http_options_kwargs: + client_kwargs["http_options"] = HttpOptions(**http_options_kwargs) + + return genai.Client(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +def convert_to_google_messages(messages: List[Dict[str, Any]]) -> str: + """Convert common message format to Google GenAI format""" + # Google GenAI uses a simpler format - just extract the first user message + for msg in messages: + if msg["role"] == "user": + if isinstance(msg["content"], str): + return msg["content"] + elif isinstance(msg["content"], list): + # Handle multimodal content + text_parts = [ + item["text"] for item in msg["content"] if item["type"] == "text" + ] + if text_parts: + return text_parts[0] + return "Hello" + + +def convert_to_google_tools(tools: List[Dict[str, Any]]) -> List[Any]: + """Convert common tool format to Google GenAI format using FunctionDeclaration""" + from google.genai import types + + google_tools = [] + + for tool in tools: + # Create a FunctionDeclaration for each tool + function_declaration = types.FunctionDeclaration( + name=tool["name"], + description=tool["description"], + parameters=types.Schema( + type=tool["parameters"]["type"].upper(), + properties={ + name: types.Schema( + type=prop["type"].upper(), + description=prop.get("description", ""), + ) + for name, prop in tool["parameters"]["properties"].items() + }, + required=tool["parameters"].get("required", []), + ), + ) + + # Create a Tool object containing the function declaration + google_tool = types.Tool(function_declarations=[function_declaration]) + google_tools.append(google_tool) + + return google_tools + + +def load_image_from_url(url: str): + """Load image from URL for Google GenAI""" + from google.genai import types + import io + import base64 + + if url.startswith("data:image"): + # Base64 image - extract the base64 data part + header, data = url.split(",", 1) + img_data = base64.b64decode(data) + image = Image.open(io.BytesIO(img_data)) + else: + # URL image + response = requests.get(url) + image = Image.open(io.BytesIO(response.content)) + + # Resize image to reduce payload size (max width/height of 512px) + max_size = 512 + if image.width > max_size or image.height > max_size: + image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) + + # Convert to RGB if necessary (for JPEG compatibility) + if image.mode in ("RGBA", "LA", "P"): + # Create a white background + background = Image.new("RGB", image.size, (255, 255, 255)) + if image.mode == "P": + image = image.convert("RGBA") + background.paste( + image, mask=image.split()[-1] if image.mode in ("RGBA", "LA") else None + ) + image = background + + # Convert PIL Image to compressed JPEG bytes + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format="JPEG", quality=85, optimize=True) + img_byte_arr = img_byte_arr.getvalue() + + # Use the correct Part.from_bytes method as per Google GenAI documentation + return types.Part.from_bytes(data=img_byte_arr, mime_type="image/jpeg") + + +class TestGoogleIntegration: + """Test suite for Google GenAI integration covering all 11 core scenarios""" + + @skip_if_no_api_key("google") + def test_01_simple_chat(self, google_client, test_config): + """Test Case 1: Simple chat interaction""" + message = convert_to_google_messages(SIMPLE_CHAT_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "chat"), contents=message + ) + + assert_valid_chat_response(response) + assert response.text is not None + assert len(response.text) > 0 + + @skip_if_no_api_key("google") + def test_02_multi_turn_conversation(self, google_client, test_config): + """Test Case 2: Multi-turn conversation""" + # Start a chat session for multi-turn + chat = google_client.chats.create(model=get_model("google", "chat")) + + # Send first message + response1 = chat.send_message("What's the capital of France?") + assert_valid_chat_response(response1) + + # Send follow-up message + response2 = chat.send_message("What's the population of that city?") + assert_valid_chat_response(response2) + + content = response2.text.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("google") + def test_03_single_tool_call(self, google_client, test_config): + """Test Case 3: Single tool call""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + message = convert_to_google_messages(SINGLE_TOOL_CALL_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents=message, + config=types.GenerateContentConfig(tools=tools), + ) + + # Check for function calls in response + assert response.candidates is not None + assert len(response.candidates) > 0 + + # Check if function call was made (Google GenAI might return function calls) + if hasattr(response, "function_calls") and response.function_calls: + assert len(response.function_calls) >= 1 + assert response.function_calls[0].name == "get_weather" + + @skip_if_no_api_key("google") + def test_04_multiple_tool_calls(self, google_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + message = convert_to_google_messages(MULTIPLE_TOOL_CALL_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents=message, + config=types.GenerateContentConfig(tools=tools), + ) + + # Check for function calls + assert response.candidates is not None + + # Check if function calls were made + if hasattr(response, "function_calls") and response.function_calls: + # Should have multiple function calls + assert len(response.function_calls) >= 1 + function_names = [fc.name for fc in response.function_calls] + # At least one of the expected tools should be called + assert any(name in ["get_weather", "calculate"] for name in function_names) + + @skip_if_no_api_key("google") + def test_05_end2end_tool_calling(self, google_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + + # Start chat for tool calling flow + chat = google_client.chats.create(model=get_model("google", "tools")) + + response1 = chat.send_message( + "What's the weather in Boston?", + config=types.GenerateContentConfig(tools=tools), + ) + + # Check if function call was made + if hasattr(response1, "function_calls") and response1.function_calls: + # Simulate function execution and send result back + for fc in response1.function_calls: + if fc.name == "get_weather": + # Mock function result and send back + response2 = chat.send_message( + types.Part.from_function_response( + name=fc.name, + response={ + "result": "The weather in Boston is 72Β°F and sunny." + }, + ) + ) + assert_valid_chat_response(response2) + + content = response2.text.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + @skip_if_no_api_key("google") + def test_06_automatic_function_calling(self, google_client, test_config): + """Test Case 6: Automatic function calling""" + from google.genai import types + + tools = convert_to_google_tools([CALCULATOR_TOOL]) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents="Calculate 25 * 4 for me", + config=types.GenerateContentConfig(tools=tools), + ) + + # Should automatically choose to use the calculator + assert response.candidates is not None + + # Check if function calls were made + if hasattr(response, "function_calls") and response.function_calls: + assert response.function_calls[0].name == "calculate" + + @skip_if_no_api_key("google") + def test_07_image_url(self, google_client, test_config): + """Test Case 7: Image analysis from URL""" + image = load_image_from_url(IMAGE_URL) + + response = google_client.models.generate_content( + model=get_model("google", "vision"), + contents=["What do you see in this image?", image], + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("google") + def test_08_image_base64(self, google_client, test_config): + """Test Case 8: Image analysis from base64""" + image = load_image_from_url(f"data:image/png;base64,{BASE64_IMAGE}") + + response = google_client.models.generate_content( + model=get_model("google", "vision"), contents=["Describe this image", image] + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("google") + def test_09_multiple_images(self, google_client, test_config): + """Test Case 9: Multiple image analysis""" + image1 = load_image_from_url(IMAGE_URL) + image2 = load_image_from_url(f"data:image/png;base64,{BASE64_IMAGE}") + + response = google_client.models.generate_content( + model=get_model("google", "vision"), + contents=["Compare these two images", image1, image2], + ) + + assert_valid_image_response(response) + content = response.text.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("google") + def test_10_complex_end2end(self, google_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + + image = load_image_from_url(IMAGE_URL) + + # Start complex conversation + chat = google_client.chats.create(model=get_model("google", "vision")) + + response1 = chat.send_message( + [ + "First, can you tell me what's in this image and then get the weather for the location shown?", + image, + ], + config=types.GenerateContentConfig(tools=tools), + ) + + # Should either describe image or call weather tool (or both) + assert response1.candidates is not None + + # Check for function calls and handle them + if hasattr(response1, "function_calls") and response1.function_calls: + for fc in response1.function_calls: + if fc.name == "get_weather": + # Send function result back + final_response = chat.send_message( + types.Part.from_function_response( + name=fc.name, + response={"result": "The weather is 72Β°F and sunny."}, + ) + ) + assert_valid_chat_response(final_response) + + @skip_if_no_api_key("google") + def test_11_integration_specific_features(self, google_client, test_config): + """Test Case 11: Google GenAI-specific features""" + + # Test 1: Generation config with temperature + from google.genai import types + + response1 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="Tell me a creative story in one sentence.", + config=types.GenerateContentConfig(temperature=0.9, max_output_tokens=100), + ) + + assert_valid_chat_response(response1) + + # Test 2: Safety settings + response2 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="Hello, how are you?", + config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( + category="HARM_CATEGORY_HARASSMENT", + threshold="BLOCK_MEDIUM_AND_ABOVE", + ) + ] + ), + ) + + assert_valid_chat_response(response2) + + # Test 3: System instruction + response3 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="high", + config=types.GenerateContentConfig( + system_instruction="I say high, you say low", + max_output_tokens=10, + ), + ) + + assert_valid_chat_response(response3) + + @skip_if_no_api_key("google") + def test_12_error_handling_invalid_roles(self, google_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + google_client.models.generate_content( + model=get_model("google", "chat"), contents=GENAI_INVALID_ROLE_CONTENT + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "google") + + @skip_if_no_api_key("google") + def test_13_streaming(self, google_client, test_config): + """Test Case 13: Streaming chat completion using Google GenAI SDK""" + + # Use the correct Google GenAI SDK streaming method + stream = google_client.models.generate_content_stream( + model=get_model("google", "chat"), + contents="Tell me a short story about a robot", + ) + + content = "" + chunk_count = 0 + + # Collect streaming content + for chunk in stream: + chunk_count += 1 + if chunk.text: + content += chunk.text + + # Validate streaming results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(content) > 10, "Should receive substantial content" + + # Check for robot-related terms (the story might not use the exact word "robot") + robot_terms = [ + "robot", + "metallic", + "programmed", + "unit", + "custodian", + "mechanical", + "android", + "machine", + ] + has_robot_content = any(term in content.lower() for term in robot_terms) + assert ( + has_robot_content + ), f"Content should relate to robots. Found content: {content[:200]}..." + + print( + f"βœ… Streaming test passed: {chunk_count} chunks, {len(content)} characters" + ) + + @skip_if_no_api_key("google") + def test_14_single_text_embedding(self, google_client, test_config): + """Test Case 21: Single text embedding generation""" + response = google_client.models.embed_content( + model="gemini-embedding-001", contents=EMBEDDINGS_SINGLE_TEXT, + config=types.EmbedContentConfig(output_dimensionality=1536) + ) + + assert_valid_embedding_response(response, expected_dimensions=1536) + + # Verify response structure + assert len(response.embeddings) == 1, "Should have exactly one embedding" + + +# Additional helper functions specific to Google GenAI +def extract_google_function_calls(response: Any) -> List[Dict[str, Any]]: + """Extract function calls from Google GenAI response format with proper type checking""" + function_calls = [] + + # Type check for Google GenAI response + if not hasattr(response, "function_calls") or not response.function_calls: + return function_calls + + for fc in response.function_calls: + if hasattr(fc, "name") and hasattr(fc, "args"): + try: + function_calls.append( + { + "name": fc.name, + "arguments": dict(fc.args) if fc.args else {}, + } + ) + except (AttributeError, TypeError) as e: + print(f"Warning: Failed to extract Google function call: {e}") + continue + + return function_calls diff --git a/tests/integrations/tests/integrations/test_langchain.py b/tests/integrations/tests/integrations/test_langchain.py new file mode 100644 index 000000000..dbbff9cc8 --- /dev/null +++ b/tests/integrations/tests/integrations/test_langchain.py @@ -0,0 +1,924 @@ +""" +LangChain Integration Tests + +🦜 LANGCHAIN COMPONENTS TESTED: +- Chat Models: OpenAI ChatOpenAI, Anthropic ChatAnthropic, Google ChatVertexAI +- Provider-Specific: Google ChatGoogleGenerativeAI, Mistral ChatMistralAI +- Embeddings: OpenAI OpenAIEmbeddings, Google VertexAIEmbeddings +- Tools: Function calling and tool integration +- Chains: LLMChain, ConversationChain, SequentialChain +- Memory: ConversationBufferMemory, ConversationSummaryMemory +- Agents: OpenAI Functions Agent, ReAct Agent +- Streaming: Real-time response streaming +- Vector Stores: Integration with embeddings and retrieval + +Tests LangChain standard interface compliance and Bifrost integration: +1. Chat model standard tests (via LangChain test suite) +2. Embeddings standard tests (via LangChain test suite) +3. Tool integration and function calling +4. Chain composition and execution +5. Memory management and conversation history +6. Agent reasoning and tool usage +7. Streaming responses and async operations +8. Vector store operations +9. Multi-provider compatibility +10. Error handling and fallbacks +11. LangChain Expression Language (LCEL) +12. Google Gemini integration via langchain-google-genai +13. Mistral AI integration via langchain-mistralai +14. Provider-specific streaming capabilities +15. Cross-provider response comparison +""" + +import pytest +import asyncio +import os +from typing import List, Dict, Any, Type, Optional +from unittest.mock import patch + +# LangChain core imports +from langchain_core.messages import HumanMessage, AIMessage, SystemMessage +from langchain_core.tools import BaseTool +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnablePassthrough + +# LangChain provider imports +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_anthropic import ChatAnthropic + +# Optional imports for providers that may not be available +try: + from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings + + GOOGLE_VERTEXAI_AVAILABLE = True +except ImportError: + GOOGLE_VERTEXAI_AVAILABLE = False + ChatVertexAI = None + VertexAIEmbeddings = None + +# Google Gemini specific imports +try: + from langchain_google_genai import ChatGoogleGenerativeAI + + GOOGLE_GENAI_AVAILABLE = True +except ImportError: + GOOGLE_GENAI_AVAILABLE = False + ChatGoogleGenerativeAI = None + +# Mistral specific imports +try: + from langchain_mistralai import ChatMistralAI + + MISTRAL_AI_AVAILABLE = True +except ImportError: + MISTRAL_AI_AVAILABLE = False + ChatMistralAI = None + +# Optional imports for legacy LangChain (chains, memory, agents) +try: + from langchain.chains import LLMChain, ConversationChain, SequentialChain + from langchain.memory import ConversationBufferMemory, ConversationSummaryMemory + from langchain.agents import ( + AgentExecutor, + create_openai_functions_agent, + create_react_agent, + ) + from langchain.agents.tools import Tool + + LEGACY_LANGCHAIN_AVAILABLE = True +except ImportError: + LEGACY_LANGCHAIN_AVAILABLE = False + LLMChain = ConversationChain = SequentialChain = None + ConversationBufferMemory = ConversationSummaryMemory = None + AgentExecutor = create_openai_functions_agent = create_react_agent = Tool = None + +# LangChain standard tests (if available) +try: + from langchain_tests.integration_tests import ChatModelIntegrationTests + from langchain_tests.integration_tests import EmbeddingsIntegrationTests + + LANGCHAIN_TESTS_AVAILABLE = True +except ImportError: + # Fallback for environments without langchain-tests + LANGCHAIN_TESTS_AVAILABLE = False + + class ChatModelIntegrationTests: + pass + + class EmbeddingsIntegrationTests: + pass + + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + EMBEDDINGS_SINGLE_TEXT, + EMBEDDINGS_MULTIPLE_TEXTS, + EMBEDDINGS_SIMILAR_TEXTS, + mock_tool_response, + assert_valid_chat_response, + assert_valid_embedding_response, + assert_valid_embeddings_batch_response, + calculate_cosine_similarity, + get_api_key, + skip_if_no_api_key, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model, get_integration_url, get_config + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +@pytest.fixture(autouse=True) +def setup_langchain(): + """Setup LangChain with Bifrost configuration and dummy credentials""" + # Set dummy credentials since Bifrost handles actual authentication + os.environ["OPENAI_API_KEY"] = "dummy-openai-key-bifrost-handles-auth" + os.environ["ANTHROPIC_API_KEY"] = "dummy-anthropic-key-bifrost-handles-auth" + os.environ["GOOGLE_API_KEY"] = "dummy-google-api-key-bifrost-handles-auth" + os.environ["VERTEX_PROJECT"] = "dummy-vertex-project" + os.environ["VERTEX_LOCATION"] = "us-central1" + + # Get Bifrost URL for LangChain + base_url = get_integration_url("langchain") + config = get_config() + integration_settings = config.get_integration_settings("langchain") + + # Store original base URLs and set Bifrost URLs + original_openai_base = os.environ.get("OPENAI_BASE_URL") + original_anthropic_base = os.environ.get("ANTHROPIC_BASE_URL") + + if base_url: + # Configure provider base URLs to route through Bifrost + os.environ["OPENAI_BASE_URL"] = f"{base_url}/v1" + os.environ["ANTHROPIC_BASE_URL"] = f"{base_url}/v1" + + yield + + # Cleanup: restore original URLs + if original_openai_base: + os.environ["OPENAI_BASE_URL"] = original_openai_base + else: + os.environ.pop("OPENAI_BASE_URL", None) + + if original_anthropic_base: + os.environ["ANTHROPIC_BASE_URL"] = original_anthropic_base + else: + os.environ.pop("ANTHROPIC_BASE_URL", None) + + +def create_langchain_tool_from_dict(tool_dict: Dict[str, Any]): + """Convert common tool format to LangChain Tool""" + if not LEGACY_LANGCHAIN_AVAILABLE: + return None + + def tool_func(**kwargs): + return mock_tool_response(tool_dict["name"], kwargs) + + return Tool( + name=tool_dict["name"], + description=tool_dict["description"], + func=tool_func, + ) + + +class TestLangChainChatOpenAI(ChatModelIntegrationTests): + """Standard LangChain tests for ChatOpenAI through Bifrost""" + + @property + def chat_model_class(self) -> Type[ChatOpenAI]: + return ChatOpenAI + + @property + def chat_model_params(self) -> dict: + return { + "model": get_model("langchain", "chat"), + "temperature": 0.7, + "max_tokens": 100, + "base_url": ( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + } + + +class TestLangChainOpenAIEmbeddings(EmbeddingsIntegrationTests): + """Standard LangChain tests for OpenAI Embeddings through Bifrost""" + + @property + def embeddings_class(self) -> Type[OpenAIEmbeddings]: + return OpenAIEmbeddings + + @property + def embeddings_params(self) -> dict: + return { + "model": get_model("langchain", "embeddings"), + "base_url": ( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + } + + +class TestLangChainIntegration: + """Comprehensive LangChain integration tests through Bifrost""" + + def test_01_chat_openai_basic(self, test_config): + """Test Case 1: Basic ChatOpenAI functionality""" + try: + chat = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + messages = [HumanMessage(content="Hello! How are you today?")] + response = chat.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + + except Exception as e: + pytest.skip(f"ChatOpenAI through LangChain not available: {e}") + + def test_02_chat_anthropic_basic(self, test_config): + """Test Case 2: Basic ChatAnthropic functionality""" + try: + chat = ChatAnthropic( + model="claude-3-haiku-20240307", + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + messages = [ + HumanMessage(content="Explain machine learning in one sentence.") + ] + response = chat.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert any( + word in response.content.lower() + for word in ["machine", "learning", "data", "algorithm"] + ) + + except Exception as e: + pytest.skip(f"ChatAnthropic through LangChain not available: {e}") + + def test_03_openai_embeddings_basic(self, test_config): + """Test Case 3: Basic OpenAI embeddings functionality""" + try: + embeddings = OpenAIEmbeddings( + model=get_model("langchain", "embeddings"), + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + # Test single embedding + result = embeddings.embed_query(EMBEDDINGS_SINGLE_TEXT) + + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(x, float) for x in result) + + # Test batch embeddings + batch_result = embeddings.embed_documents(EMBEDDINGS_MULTIPLE_TEXTS) + + assert isinstance(batch_result, list) + assert len(batch_result) == len(EMBEDDINGS_MULTIPLE_TEXTS) + assert all(isinstance(embedding, list) for embedding in batch_result) + + except Exception as e: + pytest.skip(f"OpenAI embeddings through LangChain not available: {e}") + + @pytest.mark.skipif( + not LEGACY_LANGCHAIN_AVAILABLE, reason="Legacy LangChain package not available" + ) + def test_04_function_calling_tools(self, test_config): + """Test Case 4: Function calling with tools""" + try: + chat = ChatOpenAI( + model=get_model("langchain", "tools"), + temperature=0, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + # Create tools + weather_tool = create_langchain_tool_from_dict(WEATHER_TOOL) + calculator_tool = create_langchain_tool_from_dict(CALCULATOR_TOOL) + tools = [weather_tool, calculator_tool] + + # Bind tools to the model + chat_with_tools = chat.bind_tools(tools) + + # Test tool calling + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in Boston?")] + ) + + assert isinstance(response, AIMessage) + # Should either have tool calls or mention the location + has_tool_calls = hasattr(response, "tool_calls") and response.tool_calls + mentions_location = any( + word in response.content.lower() + for word in LOCATION_KEYWORDS + WEATHER_KEYWORDS + ) + + assert ( + has_tool_calls or mentions_location + ), "Should use tools or mention weather/location" + + except Exception as e: + pytest.skip(f"Function calling through LangChain not available: {e}") + + def test_05_llm_chain_basic(self, test_config): + """Test Case 5: Basic LLM Chain functionality""" + try: + llm = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful assistant that explains concepts clearly.", + ), + ("human", "Explain {topic} in simple terms."), + ] + ) + + chain = prompt | llm | StrOutputParser() + + result = chain.invoke({"topic": "machine learning"}) + + assert isinstance(result, str) + assert len(result) > 0 + assert any( + word in result.lower() for word in ["machine", "learning", "data"] + ) + + except Exception as e: + pytest.skip(f"LLM Chain through LangChain not available: {e}") + + @pytest.mark.skipif( + not LEGACY_LANGCHAIN_AVAILABLE, reason="Legacy LangChain package not available" + ) + def test_06_conversation_memory(self, test_config): + """Test Case 6: Conversation memory functionality""" + try: + llm = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=150, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + memory = ConversationBufferMemory() + conversation = ConversationChain(llm=llm, memory=memory, verbose=False) + + # First interaction + response1 = conversation.predict( + input="My name is Alice. What's the capital of France?" + ) + assert "Paris" in response1 or "paris" in response1.lower() + + # Second interaction - should remember the name + response2 = conversation.predict(input="What's my name?") + assert "Alice" in response2 or "alice" in response2.lower() + + except Exception as e: + pytest.skip(f"Conversation memory through LangChain not available: {e}") + + def test_07_streaming_responses(self, test_config): + """Test Case 7: Streaming response functionality""" + try: + chat = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=100, + streaming=True, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + messages = [HumanMessage(content="Tell me a short story about a robot.")] + + # Collect streaming chunks + chunks = [] + for chunk in chat.stream(messages): + chunks.append(chunk) + + assert len(chunks) > 0, "Should receive streaming chunks" + + # Combine chunks to get full response + full_content = "".join(chunk.content for chunk in chunks if chunk.content) + assert len(full_content) > 0, "Should have content from streaming" + assert any(word in full_content.lower() for word in ["robot", "story"]) + + except Exception as e: + pytest.skip(f"Streaming through LangChain not available: {e}") + + def test_08_multi_provider_chain(self, test_config): + """Test Case 8: Chain with multiple provider models""" + try: + # Create different provider models + openai_chat = ChatOpenAI( + model="gpt-3.5-turbo", + temperature=0.5, + max_tokens=50, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + anthropic_chat = ChatAnthropic( + model="claude-3-haiku-20240307", + temperature=0.5, + max_tokens=50, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + # Test both models work + message = [HumanMessage(content="What is AI? Answer in one sentence.")] + + openai_response = openai_chat.invoke(message) + anthropic_response = anthropic_chat.invoke(message) + + assert isinstance(openai_response, AIMessage) + assert isinstance(anthropic_response, AIMessage) + assert ( + openai_response.content != anthropic_response.content + ) # Should be different responses + + except Exception as e: + pytest.skip(f"Multi-provider chains through LangChain not available: {e}") + + def test_09_embeddings_similarity(self, test_config): + """Test Case 9: Embeddings similarity analysis""" + try: + embeddings = OpenAIEmbeddings( + model=get_model("langchain", "embeddings"), + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + # Get embeddings for similar texts + similar_embeddings = embeddings.embed_documents(EMBEDDINGS_SIMILAR_TEXTS) + + # Calculate similarities + similarity_1_2 = calculate_cosine_similarity( + similar_embeddings[0], similar_embeddings[1] + ) + similarity_1_3 = calculate_cosine_similarity( + similar_embeddings[0], similar_embeddings[2] + ) + + # Similar texts should have high similarity + assert ( + similarity_1_2 > 0.7 + ), f"Similar texts should have high similarity, got {similarity_1_2:.4f}" + assert ( + similarity_1_3 > 0.7 + ), f"Similar texts should have high similarity, got {similarity_1_3:.4f}" + + except Exception as e: + pytest.skip(f"Embeddings similarity through LangChain not available: {e}") + + def test_10_async_operations(self, test_config): + """Test Case 10: Async operation support""" + + async def async_test(): + try: + chat = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + messages = [HumanMessage(content="Hello from async!")] + response = await chat.ainvoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + + return True + + except Exception as e: + pytest.skip(f"Async operations through LangChain not available: {e}") + return False + + # Run async test + result = asyncio.run(async_test()) + if result is not False: # Skip if not explicitly skipped + assert result is True + + def test_11_error_handling(self, test_config): + """Test Case 11: Error handling and fallbacks""" + try: + # Test with invalid model name + chat = ChatOpenAI( + model="invalid-model-name-should-fail", + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + messages = [HumanMessage(content="This should fail gracefully.")] + + with pytest.raises(Exception) as exc_info: + chat.invoke(messages) + + # Should get a meaningful error + error_message = str(exc_info.value).lower() + assert any( + word in error_message + for word in ["model", "error", "invalid", "not found"] + ) + + except Exception as e: + pytest.skip(f"Error handling test through LangChain not available: {e}") + + def test_12_langchain_expression_language(self, test_config): + """Test Case 12: LangChain Expression Language (LCEL)""" + try: + llm = ChatOpenAI( + model=get_model("langchain", "chat"), + temperature=0.7, + max_tokens=100, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}") + output_parser = StrOutputParser() + + # Create chain using LCEL + chain = prompt | llm | output_parser + + result = chain.invoke({"topic": "programming"}) + + assert isinstance(result, str) + assert len(result) > 0 + assert any( + word in result.lower() for word in ["programming", "code", "joke"] + ) + + except Exception as e: + pytest.skip(f"LCEL through LangChain not available: {e}") + + @pytest.mark.skipif( + not GOOGLE_GENAI_AVAILABLE, + reason="langchain-google-genai package not available", + ) + def test_13_gemini_chat_integration(self, test_config): + """Test Case 13: Google Gemini chat via LangChain""" + try: + # Use ChatGoogleGenerativeAI with Bifrost routing + chat = ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key="dummy-google-api-key-bifrost-handles-auth", + temperature=0.7, + max_tokens=100, + ) + + # Patch the base URL to route through Bifrost + base_url = get_integration_url("langchain") + if base_url: + # For Gemini through Bifrost, we need to route to the genai endpoint + with patch.object(chat, "_client") as mock_client: + # Set up mock to route to Bifrost + mock_client.base_url = f"{base_url}/v1beta" + + messages = [HumanMessage(content="Write a haiku about technology.")] + response = chat.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + assert any( + word in response.content.lower() + for word in ["tech", "digital", "future", "machine"] + ) + else: + pytest.skip("Bifrost URL not configured for LangChain integration") + + except Exception as e: + pytest.skip(f"Gemini through LangChain not available: {e}") + + @pytest.mark.skipif( + not MISTRAL_AI_AVAILABLE, reason="langchain-mistralai package not available" + ) + def test_14_mistral_chat_integration(self, test_config): + """Test Case 14: Mistral AI chat via LangChain""" + try: + # Mistral is OpenAI-compatible, so it can route through Bifrost easily + base_url = get_integration_url("langchain") + if base_url: + chat = ChatMistralAI( + model="mistral-7b-instruct", + mistral_api_key="dummy-mistral-api-key-bifrost-handles-auth", + endpoint=f"{base_url}/v1", # Route through Bifrost + temperature=0.7, + max_tokens=100, + ) + + messages = [ + HumanMessage(content="Explain quantum computing in simple terms.") + ] + response = chat.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + assert any( + word in response.content.lower() + for word in ["quantum", "computing", "bit", "science"] + ) + else: + pytest.skip("Bifrost URL not configured for LangChain integration") + + except Exception as e: + pytest.skip(f"Mistral through LangChain not available: {e}") + + @pytest.mark.skipif( + not GOOGLE_GENAI_AVAILABLE, + reason="langchain-google-genai package not available", + ) + def test_15_gemini_streaming(self, test_config): + """Test Case 15: Gemini streaming responses via LangChain""" + try: + chat = ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key="dummy-google-api-key-bifrost-handles-auth", + temperature=0.7, + max_tokens=100, + streaming=True, + ) + + base_url = get_integration_url("langchain") + if base_url: + with patch.object(chat, "_client") as mock_client: + mock_client.base_url = f"{base_url}/v1beta" + + messages = [ + HumanMessage(content="Tell me about artificial intelligence.") + ] + + # Collect streaming chunks + chunks = [] + for chunk in chat.stream(messages): + chunks.append(chunk) + + assert len(chunks) > 0, "Should receive streaming chunks" + + # Combine chunks to get full response + full_content = "".join( + chunk.content for chunk in chunks if chunk.content + ) + assert len(full_content) > 0, "Should have content from streaming" + assert any( + word in full_content.lower() + for word in ["artificial", "intelligence", "ai"] + ) + else: + pytest.skip("Bifrost URL not configured for LangChain integration") + + except Exception as e: + pytest.skip(f"Gemini streaming through LangChain not available: {e}") + + @pytest.mark.skipif( + not MISTRAL_AI_AVAILABLE, reason="langchain-mistralai package not available" + ) + def test_16_mistral_streaming(self, test_config): + """Test Case 16: Mistral streaming responses via LangChain""" + try: + base_url = get_integration_url("langchain") + if base_url: + chat = ChatMistralAI( + model="mistral-7b-instruct", + mistral_api_key="dummy-mistral-api-key-bifrost-handles-auth", + endpoint=f"{base_url}/v1", + temperature=0.7, + max_tokens=100, + streaming=True, + ) + + messages = [ + HumanMessage(content="Describe machine learning algorithms.") + ] + + # Collect streaming chunks + chunks = [] + for chunk in chat.stream(messages): + chunks.append(chunk) + + assert len(chunks) > 0, "Should receive streaming chunks" + + # Combine chunks to get full response + full_content = "".join( + chunk.content for chunk in chunks if chunk.content + ) + assert len(full_content) > 0, "Should have content from streaming" + assert any( + word in full_content.lower() + for word in ["machine", "learning", "algorithm"] + ) + else: + pytest.skip("Bifrost URL not configured for LangChain integration") + + except Exception as e: + pytest.skip(f"Mistral streaming through LangChain not available: {e}") + + def test_17_multi_provider_langchain_comparison(self, test_config): + """Test Case 17: Compare responses across multiple LangChain providers""" + providers_tested = [] + responses = {} + + # Test OpenAI + try: + openai_chat = ChatOpenAI( + model="gpt-3.5-turbo", + temperature=0.5, + max_tokens=50, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + message = [ + HumanMessage( + content="What is the future of AI? Answer in one sentence." + ) + ] + responses["openai"] = openai_chat.invoke(message) + providers_tested.append("OpenAI") + + except Exception: + pass + + # Test Anthropic + try: + anthropic_chat = ChatAnthropic( + model="claude-3-haiku-20240307", + temperature=0.5, + max_tokens=50, + base_url=( + get_integration_url("langchain") + if get_integration_url("langchain") + else None + ), + ) + + responses["anthropic"] = anthropic_chat.invoke(message) + providers_tested.append("Anthropic") + + except Exception: + pass + + # Test Gemini (if available) + if GOOGLE_GENAI_AVAILABLE: + try: + gemini_chat = ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key="dummy-google-api-key-bifrost-handles-auth", + temperature=0.5, + max_tokens=50, + ) + + base_url = get_integration_url("langchain") + if base_url: + with patch.object(gemini_chat, "_client") as mock_client: + mock_client.base_url = f"{base_url}/v1beta" + responses["gemini"] = gemini_chat.invoke(message) + providers_tested.append("Gemini") + + except Exception: + pass + + # Test Mistral (if available) + if MISTRAL_AI_AVAILABLE: + try: + base_url = get_integration_url("langchain") + if base_url: + mistral_chat = ChatMistralAI( + model="mistral-7b-instruct", + mistral_api_key="dummy-mistral-api-key-bifrost-handles-auth", + endpoint=f"{base_url}/v1", + temperature=0.5, + max_tokens=50, + ) + + responses["mistral"] = mistral_chat.invoke(message) + providers_tested.append("Mistral") + + except Exception: + pass + + # Verify we tested at least 2 providers + assert ( + len(providers_tested) >= 2 + ), f"Should test at least 2 providers, got: {providers_tested}" + + # Verify all responses are valid + for provider, response in responses.items(): + assert isinstance( + response, AIMessage + ), f"{provider} should return AIMessage" + assert response.content is not None, f"{provider} should have content" + assert ( + len(response.content) > 0 + ), f"{provider} should have non-empty content" + + # Verify responses are different (providers should give unique answers) + response_contents = [resp.content for resp in responses.values()] + unique_responses = set(response_contents) + assert ( + len(unique_responses) > 1 + ), "Different providers should give different responses" + + +# Skip standard tests if langchain-tests is not available +@pytest.mark.skipif( + not LANGCHAIN_TESTS_AVAILABLE, reason="langchain-tests package not available" +) +class TestLangChainStandardChatModel(TestLangChainChatOpenAI): + """Run LangChain's standard chat model tests""" + + pass + + +@pytest.mark.skipif( + not LANGCHAIN_TESTS_AVAILABLE, reason="langchain-tests package not available" +) +class TestLangChainStandardEmbeddings(TestLangChainOpenAIEmbeddings): + """Run LangChain's standard embeddings tests""" + + pass diff --git a/tests/integrations/tests/integrations/test_litellm.py b/tests/integrations/tests/integrations/test_litellm.py new file mode 100644 index 000000000..a0cdfd9f3 --- /dev/null +++ b/tests/integrations/tests/integrations/test_litellm.py @@ -0,0 +1,705 @@ +""" +LiteLLM Integration Tests + +πŸ€– MODELS USED: +- Chat: gpt-3.5-turbo (OpenAI via LiteLLM) +- Vision: gpt-4o (OpenAI via LiteLLM) +- Tools: gpt-3.5-turbo (OpenAI via LiteLLM) +- Speech: tts-1 (OpenAI via LiteLLM) +- Transcription: whisper-1 (OpenAI via LiteLLM) +- Embeddings: text-embedding-3-small (OpenAI via LiteLLM) +- Alternatives: claude-3-haiku-20240307, gemini-pro, mistral-7b-instruct, gpt-4, command-r-plus + +Tests all 19 core scenarios using LiteLLM SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +12. Error handling +13. Streaming +14. Google Gemini integration +15. Mistral integration +16. OpenAI embeddings via LiteLLM +17. OpenAI speech synthesis via LiteLLM +18. OpenAI transcription via LiteLLM +19. Multi-provider comparison +""" + +import pytest +import json +import litellm +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL_MESSAGES, + IMAGE_BASE64_MESSAGES, + MULTIPLE_IMAGES_MESSAGES, + COMPLEX_E2E_MESSAGES, + INVALID_ROLE_MESSAGES, + STREAMING_CHAT_MESSAGES, + STREAMING_TOOL_CALL_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, + assert_valid_streaming_response, + collect_streaming_content, + extract_tool_calls, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, + # Audio and embeddings test data + EMBEDDINGS_SINGLE_TEXT, + EMBEDDINGS_MULTIPLE_TEXTS, + EMBEDDINGS_SIMILAR_TEXTS, + SPEECH_TEST_INPUT, + generate_test_audio, + assert_valid_speech_response, + assert_valid_transcription_response, + assert_valid_embedding_response, + assert_valid_embeddings_batch_response, + calculate_cosine_similarity, + collect_streaming_transcription_content, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +@pytest.fixture(autouse=True) +def setup_litellm(): + """Setup LiteLLM with Bifrost configuration and dummy credentials""" + import os + from ..utils.config_loader import get_integration_url, get_config + + # Set dummy credentials since Bifrost handles actual authentication + os.environ["OPENAI_API_KEY"] = "dummy-openai-key-bifrost-handles-auth" + os.environ["ANTHROPIC_API_KEY"] = "dummy-anthropic-key-bifrost-handles-auth" + os.environ["MISTRAL_API_KEY"] = "dummy-mistral-key-bifrost-handles-auth" + + # For Google, set all possible API key environment variables + os.environ["GOOGLE_API_KEY"] = "dummy-google-api-key-bifrost-handles-auth" + os.environ["GEMINI_API_KEY"] = "dummy-gemini-api-key-bifrost-handles-auth" + os.environ["VERTEX_PROJECT"] = "dummy-vertex-project" + os.environ["VERTEX_LOCATION"] = "us-central1" + + # Get Bifrost URL for LiteLLM + base_url = get_integration_url("litellm") + config = get_config() + integration_settings = config.get_integration_settings("litellm") + api_config = config.get_api_config() + + # Configure LiteLLM globally + if base_url: + litellm.api_base = base_url + + # Set timeout and other settings + litellm.request_timeout = api_config.get("timeout", 30) + + # Apply integration-specific settings + if integration_settings.get("drop_params"): + litellm.drop_params = integration_settings["drop_params"] + if integration_settings.get("debug"): + litellm.set_verbose = integration_settings["debug"] + + +def convert_to_litellm_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to LiteLLM format (OpenAI-compatible)""" + return [{"type": "function", "function": tool} for tool in tools] + + +class TestLiteLLMIntegration: + """Test suite for LiteLLM integration covering all 11 core scenarios""" + + def test_01_simple_chat(self, test_config): + """Test Case 1: Simple chat interaction""" + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.choices[0].message.content is not None + assert len(response.choices[0].message.content) > 0 + + def test_02_multi_turn_conversation(self, test_config): + """Test Case 2: Multi-turn conversation""" + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=MULTI_TURN_MESSAGES, + max_tokens=150, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + def test_03_single_tool_call(self, test_config): + """Test Case 3: Single tool call""" + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + def test_04_multiple_tool_calls(self, test_config): + """Test Case 4: Multiple tool calls in one response""" + tools = convert_to_litellm_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=MULTIPLE_TOOL_CALL_MESSAGES, + tools=tools, + max_tokens=200, + ) + + assert_has_tool_calls(response, expected_count=2) + tool_calls = extract_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + assert "get_weather" in tool_names + assert "calculate" in tool_names + + def test_05_end2end_tool_calling(self, test_config): + """Test Case 5: Complete tool calling flow with responses""" + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's tool call to conversation + messages.append(response.choices[0].message) + + # Add tool response + tool_calls = extract_litellm_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + messages.append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "content": tool_response, + } + ) + + # Get final response + final_response = litellm.completion( + model=get_model("litellm", "chat"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(final_response) + content = final_response.choices[0].message.content.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + def test_06_automatic_function_calling(self, test_config): + """Test Case 6: Automatic function calling""" + tools = convert_to_litellm_tools([CALCULATOR_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=[{"role": "user", "content": "Calculate 25 * 4 for me"}], + tools=tools, + tool_choice="auto", + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_litellm_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + def test_07_image_url(self, test_config): + """Test Case 7: Image analysis from URL""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=IMAGE_URL_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + def test_08_image_base64(self, test_config): + """Test Case 8: Image analysis from base64""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=IMAGE_BASE64_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + def test_09_multiple_images(self, test_config): + """Test Case 9: Multiple image analysis""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=MULTIPLE_IMAGES_MESSAGES, + max_tokens=300, + ) + + assert_valid_image_response(response) + content = response.choices[0].message.content.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + def test_10_complex_end2end(self, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + messages = COMPLEX_E2E_MESSAGES.copy() + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + # First, analyze the image + response1 = litellm.completion( + model=get_model("litellm", "vision"), + messages=messages, + tools=tools, + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert ( + response1.choices[0].message.content is not None + or response1.choices[0].message.tool_calls is not None + ) + + # Add response to conversation + messages.append(response1.choices[0].message) + + # If there were tool calls, handle them + if response1.choices[0].message.tool_calls: + for tool_call in response1.choices[0].message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_response = mock_tool_response(tool_name, tool_args) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_response, + } + ) + + # Get final response after tool calls + final_response = litellm.completion( + model=get_model("litellm", "vision"), messages=messages, max_tokens=200 + ) + + assert_valid_chat_response(final_response) + + def test_11_integration_specific_features(self, test_config): + """Test Case 11: LiteLLM-specific features""" + + # Test 1: Multiple integrations through LiteLLM + integrations_to_test = [ + "gpt-3.5-turbo", # OpenAI + "claude-3-haiku-20240307", # Anthropic + "gemini-2.0-flash-001", # Google Gemini + "mistral-7b-instruct", # Mistral + ] + + for model in integrations_to_test: + try: + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=50, + ) + + assert_valid_chat_response(response) + + except Exception as e: + # Some integrations might not be available, skip gracefully + pytest.skip(f"Integration {model} not available: {e}") + + # Test 2: Function calling with specific tool choice + tools = convert_to_litellm_tools([CALCULATOR_TOOL, WEATHER_TOOL]) + + response2 = litellm.completion( + model=get_model("litellm", "chat"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=tools, + tool_choice={"type": "function", "function": {"name": "calculate"}}, + max_tokens=100, + ) + + assert_has_tool_calls(response2, expected_count=1) + tool_calls = extract_litellm_tool_calls(response2) + assert tool_calls[0]["name"] == "calculate" + + # Test 3: Temperature and other parameters + response3 = litellm.completion( + model=get_model("litellm", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + top_p=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response3) + + def test_12_error_handling_invalid_roles(self, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + litellm.completion( + model=get_model("litellm", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "litellm") + + def test_13_streaming(self, test_config): + """Test Case 13: Streaming chat completion""" + # Test basic streaming + stream = litellm.completion( + model=get_model("litellm", "chat"), + messages=STREAMING_CHAT_MESSAGES, + max_tokens=200, + stream=True, + ) + + content, chunk_count, tool_calls_detected = collect_streaming_content( + stream, "openai", timeout=30 # LiteLLM uses OpenAI format + ) + + # Validate streaming results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(content) > 10, "Should receive substantial content" + assert not tool_calls_detected, "Basic streaming shouldn't have tool calls" + + # Test streaming with tool calls + stream_with_tools = litellm.completion( + model=get_model("litellm", "tools"), + messages=STREAMING_TOOL_CALL_MESSAGES, + max_tokens=150, + tools=convert_to_litellm_tools([WEATHER_TOOL]), + stream=True, + ) + + content_tools, chunk_count_tools, tool_calls_detected_tools = ( + collect_streaming_content( + stream_with_tools, "openai", timeout=30 # LiteLLM uses OpenAI format + ) + ) + + # Validate tool streaming results + assert chunk_count_tools > 0, "Should receive at least one chunk with tools" + assert ( + tool_calls_detected_tools + ), "Should detect tool calls in streaming response" + + def test_14_gemini_integration(self, test_config): + """Test Case 14: Google Gemini integration through LiteLLM""" + try: + # Test basic chat with Gemini + response = litellm.completion( + model="vertex_ai/gemini-2.0-flash-001", + messages=[ + { + "role": "user", + "content": "What is machine learning? Answer in one sentence.", + } + ], + max_tokens=100, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + assert any( + word in content for word in ["machine", "learning", "data", "algorithm"] + ), f"Response should mention ML concepts. Got: {content}" + + # Test with tool calling if supported + tools = convert_to_litellm_tools([CALCULATOR_TOOL]) + response_tools = litellm.completion( + model="vertex_ai/gemini-2.0-flash-001", + messages=[{"role": "user", "content": "Calculate 42 * 17"}], + tools=tools, + max_tokens=100, + ) + + # Gemini should either use tools or provide calculation + if response_tools.choices[0].message.tool_calls: + assert_has_tool_calls(response_tools, expected_count=1) + else: + # Should at least provide the calculation result + content = response_tools.choices[0].message.content + assert ( + "714" in content or "42" in content + ), "Should provide calculation result" + + except Exception as e: + pytest.skip(f"Gemini integration not available: {e}") + + def test_15_mistral_integration(self, test_config): + """Test Case 15: Mistral integration through LiteLLM""" + try: + # Test basic chat with Mistral + response = litellm.completion( + model="mistral/mistral-7b-instruct", + messages=[ + { + "role": "user", + "content": "Explain recursion in programming briefly.", + } + ], + max_tokens=150, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + assert any( + word in content for word in ["recursion", "function", "itself", "call"] + ), f"Response should explain recursion. Got: {content}" + + # Test with different temperature + response_creative = litellm.completion( + model="mistral/mistral-7b-instruct", + messages=[{"role": "user", "content": "Write a haiku about code."}], + temperature=0.8, + max_tokens=100, + ) + + assert_valid_chat_response(response_creative) + + except Exception as e: + pytest.skip(f"Mistral integration not available: {e}") + + def test_16_openai_embeddings_via_litellm(self, test_config): + """Test Case 16: OpenAI embeddings through LiteLLM""" + try: + # Test single text embedding + response = litellm.embedding( + model=get_model("litellm", "embeddings") or "text-embedding-3-small", + input=EMBEDDINGS_SINGLE_TEXT, + ) + + assert_valid_embedding_response(response, expected_dimensions=1536) + + # Test batch embeddings + batch_response = litellm.embedding( + model=get_model("litellm", "embeddings") or "text-embedding-3-small", + input=EMBEDDINGS_MULTIPLE_TEXTS, + ) + + assert_valid_embeddings_batch_response( + batch_response, len(EMBEDDINGS_MULTIPLE_TEXTS), expected_dimensions=1536 + ) + + # Test similarity analysis + similar_response = litellm.embedding( + model=get_model("litellm", "embeddings") or "text-embedding-3-small", + input=EMBEDDINGS_SIMILAR_TEXTS, + ) + + embeddings = [ + item["embedding"] if isinstance(item, dict) else item.embedding + for item in ( + similar_response["data"] + if isinstance(similar_response, dict) + else similar_response.data + ) + ] + + # Calculate similarity between similar texts + similarity = calculate_cosine_similarity(embeddings[0], embeddings[1]) + assert ( + similarity > 0.7 + ), f"Similar texts should have high similarity, got {similarity:.4f}" + + except Exception as e: + pytest.skip(f"OpenAI embeddings through LiteLLM not available: {e}") + + def test_17_openai_speech_via_litellm(self, test_config): + """Test Case 17: OpenAI speech synthesis through LiteLLM""" + try: + # Test basic speech synthesis + response = litellm.speech( + model=get_model("litellm", "speech") or "tts-1", + voice="alloy", + input=SPEECH_TEST_INPUT, + ) + + # LiteLLM might return different response format + if hasattr(response, "content"): + audio_content = response.content + elif isinstance(response, bytes): + audio_content = response + else: + audio_content = response + + assert_valid_speech_response(audio_content) + + # Test with different voice + response2 = litellm.speech( + model=get_model("litellm", "speech") or "tts-1", + voice="nova", + input="Short test message for voice comparison.", + response_format="mp3", + ) + + if hasattr(response2, "content"): + audio_content2 = response2.content + elif isinstance(response2, bytes): + audio_content2 = response2 + else: + audio_content2 = response2 + + assert_valid_speech_response(audio_content2, expected_audio_size_min=500) + + # Different voices should produce different audio + assert ( + audio_content != audio_content2 + ), "Different voices should produce different audio" + + except Exception as e: + pytest.skip(f"OpenAI speech through LiteLLM not available: {e}") + + def test_18_openai_transcription_via_litellm(self, test_config): + """Test Case 18: OpenAI transcription through LiteLLM""" + try: + # Generate test audio for transcription + test_audio = generate_test_audio() + + # Test basic transcription + response = litellm.transcription( + model=get_model("litellm", "transcription") or "whisper-1", + file=("test_audio.wav", test_audio, "audio/wav"), + ) + + assert_valid_transcription_response(response) + + # Test with additional parameters + response2 = litellm.transcription( + model=get_model("litellm", "transcription") or "whisper-1", + file=("test_audio.wav", test_audio, "audio/wav"), + language="en", + temperature=0.0, + ) + + assert_valid_transcription_response(response2) + + except Exception as e: + pytest.skip(f"OpenAI transcription through LiteLLM not available: {e}") + + def test_19_multi_provider_comparison(self, test_config): + """Test Case 19: Compare responses across different providers through LiteLLM""" + test_prompt = "What is the capital of Japan? Answer in one word." + models_to_test = [ + "gpt-3.5-turbo", # OpenAI + "claude-3-haiku-20240307", # Anthropic + "vertex_ai/gemini-2.0-flash-001", # Google + ] + + responses = {} + + for model in models_to_test: + try: + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": test_prompt}], + max_tokens=50, + ) + + assert_valid_chat_response(response) + responses[model] = response.choices[0].message.content.lower() + + except Exception as e: + print(f"Model {model} not available: {e}") + continue + + # Verify that we got at least one response + assert len(responses) > 0, "Should get at least one successful response" + + # All responses should mention Tokyo or Japan + for model, content in responses.items(): + assert any( + word in content for word in ["tokyo", "japan"] + ), f"Model {model} should mention Tokyo. Got: {content}" + + +# Additional helper functions specific to LiteLLM +def extract_litellm_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from LiteLLM response format (OpenAI-compatible) with proper type checking""" + tool_calls = [] + + # Type check for LiteLLM response (OpenAI-compatible format) + if not hasattr(response, "choices") or not response.choices: + return tool_calls + + choice = response.choices[0] + if not hasattr(choice, "message") or not hasattr(choice.message, "tool_calls"): + return tool_calls + + if not choice.message.tool_calls: + return tool_calls + + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + try: + arguments = ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ) + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": arguments, + } + ) + except (json.JSONDecodeError, AttributeError) as e: + print(f"Warning: Failed to parse LiteLLM tool call arguments: {e}") + continue + + return tool_calls diff --git a/tests/integrations/tests/integrations/test_openai.py b/tests/integrations/tests/integrations/test_openai.py new file mode 100644 index 000000000..4a9a61ea0 --- /dev/null +++ b/tests/integrations/tests/integrations/test_openai.py @@ -0,0 +1,1056 @@ +""" +OpenAI Integration Tests + +πŸ€– MODELS USED: +- Chat: gpt-3.5-turbo +- Vision: gpt-4o +- Tools: gpt-3.5-turbo +- Speech: tts-1 +- Transcription: whisper-1 +- Embeddings: text-embedding-3-small +- Alternatives: gpt-4, gpt-4-turbo-preview, gpt-4o, gpt-4o-mini + +Tests all core scenarios using OpenAI SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +12. Error handling +13. Streaming chat +14. Speech synthesis +15. Audio transcription +16. Transcription streaming +17. Speech-transcription round trip +18. Speech error handling +19. Transcription error handling +20. Different voices and audio formats +21. Single text embedding +22. Batch text embeddings +23. Embedding similarity analysis +24. Embedding dissimilarity analysis +25. Different embedding models +26. Long text embedding +27. Embedding error handling +28. Embedding dimensionality reduction +29. Embedding encoding formats +30. Embedding usage tracking +""" + +import pytest +import json +from openai import OpenAI +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL_MESSAGES, + IMAGE_BASE64_MESSAGES, + MULTIPLE_IMAGES_MESSAGES, + COMPLEX_E2E_MESSAGES, + INVALID_ROLE_MESSAGES, + STREAMING_CHAT_MESSAGES, + STREAMING_TOOL_CALL_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, + assert_valid_streaming_response, + collect_streaming_content, + extract_tool_calls, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, + # Speech and Transcription utilities + SPEECH_TEST_INPUT, + SPEECH_TEST_VOICES, + TRANSCRIPTION_TEST_INPUTS, + generate_test_audio, + TEST_AUDIO_DATA, + assert_valid_speech_response, + assert_valid_transcription_response, + assert_valid_streaming_speech_response, + assert_valid_streaming_transcription_response, + collect_streaming_speech_content, + collect_streaming_transcription_content, + # Embeddings utilities + EMBEDDINGS_SINGLE_TEXT, + EMBEDDINGS_MULTIPLE_TEXTS, + EMBEDDINGS_SIMILAR_TEXTS, + EMBEDDINGS_DIFFERENT_TEXTS, + EMBEDDINGS_EMPTY_TEXTS, + EMBEDDINGS_LONG_TEXT, + assert_valid_embedding_response, + assert_valid_embeddings_batch_response, + calculate_cosine_similarity, + assert_embeddings_similarity, + assert_embeddings_dissimilarity, +) +from ..utils.config_loader import get_model + + +# Helper functions (defined early for use in test methods) +def extract_openai_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from OpenAI response format with proper type checking""" + tool_calls = [] + + # Type check for OpenAI ChatCompletion response + if not hasattr(response, "choices") or not response.choices: + return tool_calls + + choice = response.choices[0] + if not hasattr(choice, "message") or not hasattr(choice.message, "tool_calls"): + return tool_calls + + if not choice.message.tool_calls: + return tool_calls + + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + try: + arguments = ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ) + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": arguments, + } + ) + except (json.JSONDecodeError, AttributeError) as e: + print(f"Warning: Failed to parse tool call arguments: {e}") + continue + + return tool_calls + + +def convert_to_openai_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to OpenAI format""" + return [{"type": "function", "function": tool} for tool in tools] + + +@pytest.fixture +def openai_client(): + """Create OpenAI client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("openai") + base_url = get_integration_url("openai") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("openai") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add optional OpenAI-specific settings + if integration_settings.get("organization"): + client_kwargs["organization"] = integration_settings["organization"] + if integration_settings.get("project"): + client_kwargs["project"] = integration_settings["project"] + + return OpenAI(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +class TestOpenAIIntegration: + """Test suite for OpenAI integration covering all 11 core scenarios""" + + @skip_if_no_api_key("openai") + def test_01_simple_chat(self, openai_client, test_config): + """Test Case 1: Simple chat interaction""" + response = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.choices[0].message.content is not None + assert len(response.choices[0].message.content) > 0 + + @skip_if_no_api_key("openai") + def test_02_multi_turn_conversation(self, openai_client, test_config): + """Test Case 2: Multi-turn conversation""" + response = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=MULTI_TURN_MESSAGES, + max_tokens=150, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("openai") + def test_03_single_tool_call(self, openai_client, test_config): + """Test Case 3: Single tool call""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + @skip_if_no_api_key("openai") + def test_04_multiple_tool_calls(self, openai_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=MULTIPLE_TOOL_CALL_MESSAGES, + tools=[ + {"type": "function", "function": WEATHER_TOOL}, + {"type": "function", "function": CALCULATOR_TOOL}, + ], + max_tokens=200, + ) + + assert_has_tool_calls(response, expected_count=2) + tool_calls = extract_openai_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + assert "get_weather" in tool_names + assert "calculate" in tool_names + + @skip_if_no_api_key("openai") + def test_05_end2end_tool_calling(self, openai_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + # Initial request + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=messages, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's tool call to conversation + messages.append(response.choices[0].message) + + # Add tool response + tool_calls = extract_openai_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + messages.append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "content": tool_response, + } + ) + + # Get final response + final_response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(final_response) + content = final_response.choices[0].message.content.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + @skip_if_no_api_key("openai") + def test_06_automatic_function_calling(self, openai_client, test_config): + """Test Case 6: Automatic function calling (tool_choice='auto')""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=[{"role": "user", "content": "Calculate 25 * 4 for me"}], + tools=[{"type": "function", "function": CALCULATOR_TOOL}], + tool_choice="auto", # Let model decide + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_openai_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + @skip_if_no_api_key("openai") + def test_07_image_url(self, openai_client, test_config): + """Test Case 7: Image analysis from URL""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=IMAGE_URL_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("openai") + def test_08_image_base64(self, openai_client, test_config): + """Test Case 8: Image analysis from base64""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=IMAGE_BASE64_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("openai") + def test_09_multiple_images(self, openai_client, test_config): + """Test Case 9: Multiple image analysis""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=MULTIPLE_IMAGES_MESSAGES, + max_tokens=300, + ) + + assert_valid_image_response(response) + content = response.choices[0].message.content.lower() + # Should mention comparison or differences (flexible matching) + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("openai") + def test_10_complex_end2end(self, openai_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + messages = COMPLEX_E2E_MESSAGES.copy() + + # First, analyze the image + response1 = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=messages, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert ( + response1.choices[0].message.content is not None + or response1.choices[0].message.tool_calls is not None + ) + + # Add response to conversation + messages.append(response1.choices[0].message) + + # If there were tool calls, handle them + if response1.choices[0].message.tool_calls: + for tool_call in response1.choices[0].message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_response = mock_tool_response(tool_name, tool_args) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_response, + } + ) + + # Get final response after tool calls + final_response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), messages=messages, max_tokens=200 + ) + + assert_valid_chat_response(final_response) + + @skip_if_no_api_key("openai") + def test_11_integration_specific_features(self, openai_client, test_config): + """Test Case 11: OpenAI-specific features""" + + # Test 1: Function calling with specific tool choice + response1 = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=[ + {"type": "function", "function": CALCULATOR_TOOL}, + {"type": "function", "function": WEATHER_TOOL}, + ], + tool_choice={ + "type": "function", + "function": {"name": "calculate"}, + }, # Force specific tool + max_tokens=100, + ) + + assert_has_tool_calls(response1, expected_count=1) + tool_calls = extract_openai_tool_calls(response1) + assert tool_calls[0]["name"] == "calculate" + + # Test 2: System message + response2 = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that always responds in exactly 5 words.", + }, + {"role": "user", "content": "Hello, how are you?"}, + ], + max_tokens=50, + ) + + assert_valid_chat_response(response2) + # Check if response is approximately 5 words (allow some flexibility) + word_count = len(response2.choices[0].message.content.split()) + assert 3 <= word_count <= 7, f"Expected ~5 words, got {word_count}" + + # Test 3: Temperature and top_p parameters + response3 = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + top_p=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response3) + + @skip_if_no_api_key("openai") + def test_12_error_handling_invalid_roles(self, openai_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "openai") + + @skip_if_no_api_key("openai") + def test_13_streaming(self, openai_client, test_config): + """Test Case 13: Streaming chat completion""" + # Test basic streaming + stream = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=STREAMING_CHAT_MESSAGES, + max_tokens=200, + stream=True, + ) + + content, chunk_count, tool_calls_detected = collect_streaming_content( + stream, "openai", timeout=30 + ) + + # Validate streaming results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(content) > 10, "Should receive substantial content" + assert not tool_calls_detected, "Basic streaming shouldn't have tool calls" + + # Test streaming with tool calls + stream_with_tools = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=STREAMING_TOOL_CALL_MESSAGES, + max_tokens=150, + tools=convert_to_openai_tools([WEATHER_TOOL]), + stream=True, + ) + + content_tools, chunk_count_tools, tool_calls_detected_tools = ( + collect_streaming_content(stream_with_tools, "openai", timeout=30) + ) + + # Validate tool streaming results + assert chunk_count_tools > 0, "Should receive at least one chunk with tools" + assert ( + tool_calls_detected_tools + ), "Should detect tool calls in streaming response" + + @skip_if_no_api_key("openai") + def test_14_speech_synthesis(self, openai_client, test_config): + """Test Case 14: Speech synthesis (text-to-speech)""" + # Basic speech synthesis test + response = openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="alloy", + input=SPEECH_TEST_INPUT, + ) + + # Read the audio content + audio_content = response.content + assert_valid_speech_response(audio_content) + + # Test with different voice + response2 = openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="nova", + input="Short test message.", + response_format="mp3", + ) + + audio_content2 = response2.content + assert_valid_speech_response(audio_content2, expected_audio_size_min=500) + + # Verify that different voices produce different audio + assert ( + audio_content != audio_content2 + ), "Different voices should produce different audio" + + @skip_if_no_api_key("openai") + def test_15_transcription_audio(self, openai_client, test_config): + """Test Case 16: Audio transcription (speech-to-text)""" + # Generate test audio for transcription + test_audio = generate_test_audio() + + # Basic transcription test + response = openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("test_audio.wav", test_audio, "audio/wav"), + ) + + assert_valid_transcription_response(response) + # Since we're using a generated sine wave, we don't expect specific text, + # but the API should return some transcription attempt + + # Test with additional parameters + response2 = openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("test_audio.wav", test_audio, "audio/wav"), + language="en", + temperature=0.0, + ) + + assert_valid_transcription_response(response2) + + @skip_if_no_api_key("openai") + def test_16_transcription_streaming(self, openai_client, test_config): + """Test Case 17: Audio transcription streaming""" + # Generate test audio for streaming transcription + test_audio = generate_test_audio() + + try: + # Try to create streaming transcription + response = openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("test_audio.wav", test_audio, "audio/wav"), + stream=True, + ) + + # If streaming is supported, collect the text chunks + if hasattr(response, "__iter__"): + text_content, chunk_count = collect_streaming_transcription_content( + response, "openai", timeout=60 + ) + assert chunk_count > 0, "Should receive at least one text chunk" + assert_valid_transcription_response( + text_content, min_text_length=0 + ) # Sine wave might not produce much text + else: + # If not streaming, should still be valid transcription + assert_valid_transcription_response(response) + + except Exception as e: + # If streaming is not supported, ensure it's a proper error message + error_message = str(e).lower() + streaming_not_supported = any( + phrase in error_message + for phrase in ["streaming", "not supported", "invalid", "stream"] + ) + if not streaming_not_supported: + # Re-raise if it's not a streaming support issue + raise + + @skip_if_no_api_key("openai") + def test_17_speech_transcription_round_trip(self, openai_client, test_config): + """Test Case 18: Complete round-trip - text to speech to text""" + original_text = "The quick brown fox jumps over the lazy dog." + + # Step 1: Convert text to speech + speech_response = openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="alloy", + input=original_text, + response_format="wav", # Use WAV for better transcription compatibility + ) + + audio_content = speech_response.content + assert_valid_speech_response(audio_content) + + # Step 2: Convert speech back to text + transcription_response = openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("generated_speech.wav", audio_content, "audio/wav"), + ) + + assert_valid_transcription_response(transcription_response) + transcribed_text = transcription_response.text + + # Step 3: Verify similarity (allowing for some variation in transcription) + # Check for key words from the original text + original_words = original_text.lower().split() + transcribed_words = transcribed_text.lower().split() + + # At least 50% of the original words should be present in the transcription + matching_words = sum(1 for word in original_words if word in transcribed_words) + match_percentage = matching_words / len(original_words) + + assert match_percentage >= 0.3, ( + f"Round-trip transcription should preserve at least 30% of original words. " + f"Original: '{original_text}', Transcribed: '{transcribed_text}', " + f"Match percentage: {match_percentage:.2%}" + ) + + @skip_if_no_api_key("openai") + def test_18_speech_error_handling(self, openai_client, test_config): + """Test Case 19: Speech synthesis error handling""" + # Test with invalid voice + with pytest.raises(Exception) as exc_info: + openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="invalid_voice_name", + input="This should fail.", + ) + + error = exc_info.value + assert_valid_error_response(error, "invalid_voice_name") + + # Test with empty input + with pytest.raises(Exception) as exc_info: + openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="alloy", + input="", + ) + + error = exc_info.value + # Should get an error for empty input + + # Test with invalid model + with pytest.raises(Exception) as exc_info: + openai_client.audio.speech.create( + model="invalid-speech-model", + voice="alloy", + input="This should fail due to invalid model.", + ) + + error = exc_info.value + # Should get an error for invalid model + + @skip_if_no_api_key("openai") + def test_19_transcription_error_handling(self, openai_client, test_config): + """Test Case 20: Transcription error handling""" + # Test with invalid audio data + invalid_audio = b"This is not audio data" + + with pytest.raises(Exception) as exc_info: + openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("invalid.wav", invalid_audio, "audio/wav"), + ) + + error = exc_info.value + # Should get an error for invalid audio format + + # Test with invalid model + valid_audio = generate_test_audio() + + with pytest.raises(Exception) as exc_info: + openai_client.audio.transcriptions.create( + model="invalid-transcription-model", + file=("test.wav", valid_audio, "audio/wav"), + ) + + error = exc_info.value + # Should get an error for invalid model + + # Test with unsupported file format (if applicable) + with pytest.raises(Exception) as exc_info: + openai_client.audio.transcriptions.create( + model=get_model("openai", "transcription"), + file=("test.txt", b"text file content", "text/plain"), + ) + + error = exc_info.value + # Should get an error for unsupported file type + + @skip_if_no_api_key("openai") + def test_20_speech_different_voices_and_formats(self, openai_client, test_config): + """Test Case 21: Test different voices and response formats""" + test_text = "Testing different voices and audio formats." + + # Test multiple voices + voices_tested = [] + for voice in SPEECH_TEST_VOICES[ + :3 + ]: # Test first 3 voices to avoid too many API calls + response = openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice=voice, + input=test_text, + response_format="mp3", + ) + + audio_content = response.content + assert_valid_speech_response(audio_content) + voices_tested.append((voice, len(audio_content))) + + # Verify that different voices produce different sized outputs (generally) + sizes = [size for _, size in voices_tested] + assert len(set(sizes)) > 1 or all( + s > 1000 for s in sizes + ), "Different voices should produce varying audio outputs" + + # Test different response formats + formats_to_test = ["mp3", "wav", "opus"] + format_results = [] + + for format_type in formats_to_test: + try: + response = openai_client.audio.speech.create( + model=get_model("openai", "speech"), + voice="alloy", + input="Testing audio format: " + format_type, + response_format=format_type, + ) + + audio_content = response.content + assert_valid_speech_response(audio_content, expected_audio_size_min=500) + format_results.append(format_type) + + except Exception as e: + # Some formats might not be supported + print(f"Format {format_type} not supported or failed: {e}") + + # At least MP3 should be supported + assert "mp3" in format_results, "MP3 format should be supported" + + @skip_if_no_api_key("openai") + def test_21_single_text_embedding(self, openai_client, test_config): + """Test Case 21: Single text embedding generation""" + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_SINGLE_TEXT + ) + + assert_valid_embedding_response(response, expected_dimensions=1536) + + # Verify response structure + assert len(response.data) == 1, "Should have exactly one embedding" + assert response.data[0].index == 0, "First embedding should have index 0" + assert ( + response.data[0].object == "embedding" + ), "Object type should be 'embedding'" + + # Verify model in response + assert response.model is not None, "Response should include model name" + assert "text-embedding" in response.model, "Model should be an embedding model" + + @skip_if_no_api_key("openai") + def test_22_batch_text_embeddings(self, openai_client, test_config): + """Test Case 22: Batch text embedding generation""" + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_MULTIPLE_TEXTS + ) + + expected_count = len(EMBEDDINGS_MULTIPLE_TEXTS) + assert_valid_embeddings_batch_response( + response, expected_count, expected_dimensions=1536 + ) + + # Verify each embedding has correct index + for i, embedding_obj in enumerate(response.data): + assert embedding_obj.index == i, f"Embedding {i} should have index {i}" + assert ( + embedding_obj.object == "embedding" + ), f"Embedding {i} should have object type 'embedding'" + + @skip_if_no_api_key("openai") + def test_23_embedding_similarity_analysis(self, openai_client, test_config): + """Test Case 23: Embedding similarity analysis with similar texts""" + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_SIMILAR_TEXTS + ) + + assert_valid_embeddings_batch_response( + response, len(EMBEDDINGS_SIMILAR_TEXTS), expected_dimensions=1536 + ) + + embeddings = [item.embedding for item in response.data] + + # Test similarity between the first two embeddings (similar weather texts) + similarity_1_2 = calculate_cosine_similarity(embeddings[0], embeddings[1]) + similarity_1_3 = calculate_cosine_similarity(embeddings[0], embeddings[2]) + similarity_2_3 = calculate_cosine_similarity(embeddings[1], embeddings[2]) + + # Similar texts should have high similarity (> 0.7) + assert ( + similarity_1_2 > 0.7 + ), f"Similar texts should have high similarity, got {similarity_1_2:.4f}" + assert ( + similarity_1_3 > 0.7 + ), f"Similar texts should have high similarity, got {similarity_1_3:.4f}" + assert ( + similarity_2_3 > 0.7 + ), f"Similar texts should have high similarity, got {similarity_2_3:.4f}" + + @skip_if_no_api_key("openai") + def test_24_embedding_dissimilarity_analysis(self, openai_client, test_config): + """Test Case 24: Embedding dissimilarity analysis with different texts""" + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_DIFFERENT_TEXTS + ) + + assert_valid_embeddings_batch_response( + response, len(EMBEDDINGS_DIFFERENT_TEXTS), expected_dimensions=1536 + ) + + embeddings = [item.embedding for item in response.data] + + # Test dissimilarity between different topic embeddings + # Weather vs Programming + weather_prog_similarity = calculate_cosine_similarity( + embeddings[0], embeddings[1] + ) + # Weather vs Stock Market + weather_stock_similarity = calculate_cosine_similarity( + embeddings[0], embeddings[2] + ) + # Programming vs Machine Learning (should be more similar) + prog_ml_similarity = calculate_cosine_similarity(embeddings[1], embeddings[3]) + + # Different topics should have lower similarity + assert ( + weather_prog_similarity < 0.8 + ), f"Different topics should have lower similarity, got {weather_prog_similarity:.4f}" + assert ( + weather_stock_similarity < 0.8 + ), f"Different topics should have lower similarity, got {weather_stock_similarity:.4f}" + + # Programming and ML should be more similar than completely different topics + assert ( + prog_ml_similarity > weather_prog_similarity + ), "Related tech topics should be more similar than unrelated topics" + + @skip_if_no_api_key("openai") + def test_25_embedding_different_models(self, openai_client, test_config): + """Test Case 25: Test different embedding models""" + test_text = EMBEDDINGS_SINGLE_TEXT + + # Test with text-embedding-3-small (default) + response_small = openai_client.embeddings.create( + model="text-embedding-3-small", input=test_text + ) + assert_valid_embedding_response(response_small, expected_dimensions=1536) + + # Test with text-embedding-3-large if available + try: + response_large = openai_client.embeddings.create( + model="text-embedding-3-large", input=test_text + ) + assert_valid_embedding_response(response_large, expected_dimensions=3072) + + # Verify different models produce different embeddings + embedding_small = response_small.data[0].embedding + embedding_large = response_large.data[0].embedding + + # They should have different dimensions + assert len(embedding_small) != len( + embedding_large + ), "Different models should produce different dimension embeddings" + + except Exception as e: + # If text-embedding-3-large is not available, just log it + print(f"text-embedding-3-large not available: {e}") + + @skip_if_no_api_key("openai") + def test_26_embedding_long_text(self, openai_client, test_config): + """Test Case 26: Embedding generation with longer text""" + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_LONG_TEXT + ) + + assert_valid_embedding_response(response, expected_dimensions=1536) + + # Verify token usage is reported for longer text + assert response.usage is not None, "Usage should be reported for longer text" + assert ( + response.usage.total_tokens > 20 + ), "Longer text should consume more tokens" + + @skip_if_no_api_key("openai") + def test_27_embedding_error_handling(self, openai_client, test_config): + """Test Case 27: Embedding error handling""" + + # Test with invalid model + with pytest.raises(Exception) as exc_info: + openai_client.embeddings.create( + model="invalid-embedding-model", input=EMBEDDINGS_SINGLE_TEXT + ) + + error = exc_info.value + assert_valid_error_response(error, "invalid-embedding-model") + + # Test with empty text (depending on implementation, might be handled) + try: + response = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input="" + ) + # If it doesn't throw an error, check that response is still valid + if response: + assert_valid_embedding_response(response) + + except Exception as e: + # Empty input might be rejected, which is acceptable + assert ( + "empty" in str(e).lower() or "invalid" in str(e).lower() + ), "Error should mention empty or invalid input" + + @skip_if_no_api_key("openai") + def test_28_embedding_dimensionality_reduction(self, openai_client, test_config): + """Test Case 28: Embedding with custom dimensions (if supported)""" + try: + # Test custom dimensions with text-embedding-3-small + custom_dimensions = 512 + response = openai_client.embeddings.create( + model="text-embedding-3-small", + input=EMBEDDINGS_SINGLE_TEXT, + dimensions=custom_dimensions, + ) + + assert_valid_embedding_response( + response, expected_dimensions=custom_dimensions + ) + + # Compare with default dimensions + response_default = openai_client.embeddings.create( + model="text-embedding-3-small", input=EMBEDDINGS_SINGLE_TEXT + ) + + embedding_custom = response.data[0].embedding + embedding_default = response_default.data[0].embedding + + assert ( + len(embedding_custom) == custom_dimensions + ), f"Custom dimensions should be {custom_dimensions}" + assert len(embedding_default) == 1536, "Default dimensions should be 1536" + assert len(embedding_custom) != len( + embedding_default + ), "Custom and default dimensions should be different" + + except Exception as e: + # Custom dimensions might not be supported by all models + print(f"Custom dimensions not supported: {e}") + + @skip_if_no_api_key("openai") + def test_29_embedding_encoding_format(self, openai_client, test_config): + """Test Case 29: Different encoding formats (if supported)""" + try: + # Test with float encoding (default) + response_float = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), + input=EMBEDDINGS_SINGLE_TEXT, + encoding_format="float", + ) + + assert_valid_embedding_response(response_float, expected_dimensions=1536) + embedding_float = response_float.data[0].embedding + assert all( + isinstance(x, float) for x in embedding_float + ), "Float encoding should return float values" + + # Test with base64 encoding if supported + try: + response_base64 = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), + input=EMBEDDINGS_SINGLE_TEXT, + encoding_format="base64", + ) + + # Base64 encoding returns string data + assert ( + response_base64.data[0].embedding is not None + ), "Base64 encoding should return data" + + except Exception as base64_error: + print(f"Base64 encoding not supported: {base64_error}") + + except Exception as e: + # Encoding format parameter might not be supported + print(f"Encoding format parameter not supported: {e}") + + @skip_if_no_api_key("openai") + def test_30_embedding_usage_tracking(self, openai_client, test_config): + """Test Case 30: Embedding usage tracking and token counting""" + # Single text embedding + response_single = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_SINGLE_TEXT + ) + + assert_valid_embedding_response(response_single) + assert ( + response_single.usage is not None + ), "Single embedding should have usage data" + assert ( + response_single.usage.total_tokens > 0 + ), "Single embedding should consume tokens" + single_tokens = response_single.usage.total_tokens + + # Batch embedding + response_batch = openai_client.embeddings.create( + model=get_model("openai", "embeddings"), input=EMBEDDINGS_MULTIPLE_TEXTS + ) + + assert_valid_embeddings_batch_response( + response_batch, len(EMBEDDINGS_MULTIPLE_TEXTS) + ) + assert ( + response_batch.usage is not None + ), "Batch embedding should have usage data" + assert ( + response_batch.usage.total_tokens > 0 + ), "Batch embedding should consume tokens" + batch_tokens = response_batch.usage.total_tokens + + # Batch should consume more tokens than single + assert ( + batch_tokens > single_tokens + ), f"Batch embedding ({batch_tokens} tokens) should consume more than single ({single_tokens} tokens)" + + # Verify proportional token usage + texts_ratio = len(EMBEDDINGS_MULTIPLE_TEXTS) + token_ratio = batch_tokens / single_tokens + + # Token ratio should be roughly proportional to text count (allowing for some variance) + assert ( + 0.5 * texts_ratio <= token_ratio <= 2.0 * texts_ratio + ), f"Token usage ratio ({token_ratio:.2f}) should be roughly proportional to text count ({texts_ratio})" diff --git a/tests/integrations/tests/utils/__init__.py b/tests/integrations/tests/utils/__init__.py new file mode 100644 index 000000000..d0ba24ae9 --- /dev/null +++ b/tests/integrations/tests/utils/__init__.py @@ -0,0 +1 @@ +# Utils package for shared test utilities diff --git a/tests/integrations/tests/utils/common.py b/tests/integrations/tests/utils/common.py new file mode 100644 index 000000000..a79e86f00 --- /dev/null +++ b/tests/integrations/tests/utils/common.py @@ -0,0 +1,1397 @@ +""" +Common utilities and test data for all integration tests. +This module contains shared functions, test data, and assertions +that can be used across all integration-specific test files. +""" + +import ast +import base64 +import json +import operator +import os +from typing import Dict, List, Any, Optional +from dataclasses import dataclass + + +# Test Configuration +@dataclass +class Config: + """Configuration for test execution""" + + timeout: int = 30 + max_retries: int = 3 + debug: bool = False + + +# Common Test Data +SIMPLE_CHAT_MESSAGES = [{"role": "user", "content": "Hello! How are you today?"}] + +MULTI_TURN_MESSAGES = [ + {"role": "user", "content": "What's the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What's the population of that city?"}, +] + +# Tool Definitions +WEATHER_TOOL = { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit", + }, + }, + "required": ["location"], + }, +} + +CALCULATOR_TOOL = { + "name": "calculate", + "description": "Perform basic mathematical calculations", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate, e.g. '2 + 2'", + } + }, + "required": ["expression"], + }, +} + +SEARCH_TOOL = { + "name": "search_web", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query"}}, + "required": ["query"], + }, +} + +ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL] + +# Embeddings Test Data +EMBEDDINGS_SINGLE_TEXT = "The quick brown fox jumps over the lazy dog." + +EMBEDDINGS_MULTIPLE_TEXTS = [ + "Artificial intelligence is transforming our world.", + "Machine learning algorithms learn from data to make predictions.", + "Natural language processing helps computers understand human language.", + "Computer vision enables machines to interpret and analyze visual information.", + "Robotics combines AI with mechanical engineering to create autonomous systems.", +] + +EMBEDDINGS_SIMILAR_TEXTS = [ + "The weather is sunny and warm today.", + "Today has bright sunshine and pleasant temperatures.", + "It's a beautiful day with clear skies and warmth.", +] + +EMBEDDINGS_DIFFERENT_TEXTS = [ + "The weather is sunny and warm today.", + "Python is a popular programming language.", + "The stock market closed higher yesterday.", + "Machine learning requires large datasets.", +] + +EMBEDDINGS_EMPTY_TEXTS = ["", " ", "\n\t", ""] + +EMBEDDINGS_LONG_TEXT = """ +This is a longer text sample designed to test how embedding models handle +larger inputs. It contains multiple sentences with various topics including +technology, science, literature, and general knowledge. The purpose is to +ensure that the embedding generation works correctly with substantial text +inputs that might be closer to real-world usage scenarios where users +embed entire paragraphs or documents rather than just short phrases. +""".strip() + +# Tool Call Test Messages +SINGLE_TOOL_CALL_MESSAGES = [ + {"role": "user", "content": "What's the weather like in San Francisco?"} +] + +MULTIPLE_TOOL_CALL_MESSAGES = [ + {"role": "user", "content": "What's the weather in New York and calculate 15 * 23?"} +] + +# Streaming Test Messages +STREAMING_CHAT_MESSAGES = [ + { + "role": "user", + "content": "Tell me a short story about a robot learning to paint. Keep it under 200 words.", + } +] + +STREAMING_TOOL_CALL_MESSAGES = [ + { + "role": "user", + "content": "What's the weather like in San Francisco? Please use the get_weather function.", + } +] + +# Image Test Data +IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + +# Small test image as base64 (1x1 pixel red PNG) +BASE64_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + +IMAGE_URL_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + ], + } +] + +IMAGE_BASE64_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE}"}, + }, + ], + } +] + +MULTIPLE_IMAGES_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Compare these two images"}, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE}"}, + }, + ], + } +] + +# Complex End-to-End Test Data +COMPLEX_E2E_MESSAGES = [ + {"role": "user", "content": "Hello! I need help with some tasks."}, + { + "role": "assistant", + "content": "Hello! I'd be happy to help you with your tasks. What do you need assistance with?", + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "First, can you tell me what's in this image and then get the weather for the location shown?", + }, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + ], + }, +] + +# Common keyword arrays for flexible assertions +COMPARISON_KEYWORDS = [ + "compare", + "comparison", + "different", + "difference", + "differences", + "both", + "two", + "first", + "second", + "images", + "image", + "versus", + "vs", + "contrast", + "unlike", + "while", + "whereas", +] + +WEATHER_KEYWORDS = [ + "weather", + "temperature", + "sunny", + "cloudy", + "rain", + "snow", + "celsius", + "fahrenheit", + "degrees", + "hot", + "cold", + "warm", + "cool", +] + +LOCATION_KEYWORDS = ["boston", "san francisco", "new york", "city", "location", "place"] + +# Error test data for invalid role testing +INVALID_ROLE_MESSAGES = [ + {"role": "tester", "content": "Hello! This should fail due to invalid role."} +] + +# GenAI-specific invalid role content that passes SDK validation but fails at Bifrost +GENAI_INVALID_ROLE_CONTENT = [ + { + "role": "tester", # Invalid role that should be caught by Bifrost + "parts": [ + {"text": "Hello! This should fail due to invalid role in GenAI format."} + ], + } +] + +# Error keywords for validating error messages +ERROR_KEYWORDS = [ + "invalid", + "error", + "role", + "tester", + "unsupported", + "unknown", + "bad", + "incorrect", + "not allowed", + "not supported", + "forbidden", +] + + +# Helper Functions +def safe_eval_arithmetic(expression: str) -> float: + """ + Safely evaluate arithmetic expressions using AST parsing. + Only allows basic arithmetic operations: +, -, *, /, **, (), and numbers. + + Args: + expression: String containing arithmetic expression + + Returns: + Evaluated result as float + + Raises: + ValueError: If expression contains unsupported operations + SyntaxError: If expression has invalid syntax + ZeroDivisionError: If division by zero occurs + """ + # Allowed operations mapping + ALLOWED_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + def eval_node(node): + """Recursively evaluate AST nodes""" + if isinstance(node, ast.Constant): # Numbers + return node.value + elif isinstance(node, ast.Num): # Numbers (Python < 3.8 compatibility) + return node.n + elif isinstance(node, ast.UnaryOp): + if type(node.op) in ALLOWED_OPS: + return ALLOWED_OPS[type(node.op)](eval_node(node.operand)) + else: + raise ValueError( + f"Unsupported unary operation: {type(node.op).__name__}" + ) + elif isinstance(node, ast.BinOp): + if type(node.op) in ALLOWED_OPS: + left = eval_node(node.left) + right = eval_node(node.right) + return ALLOWED_OPS[type(node.op)](left, right) + else: + raise ValueError( + f"Unsupported binary operation: {type(node.op).__name__}" + ) + else: + raise ValueError(f"Unsupported expression type: {type(node).__name__}") + + try: + # Parse the expression into an AST + tree = ast.parse(expression, mode="eval") + # Evaluate the AST + return eval_node(tree.body) + except SyntaxError as e: + raise SyntaxError(f"Invalid syntax in expression '{expression}': {e}") + except ZeroDivisionError: + raise ZeroDivisionError(f"Division by zero in expression '{expression}'") + except Exception as e: + raise ValueError(f"Error evaluating expression '{expression}': {e}") + + +def mock_tool_response(tool_name: str, args: Dict[str, Any]) -> str: + """Generate mock responses for tool calls""" + if tool_name == "get_weather": + location = args.get("location", "Unknown") + unit = args.get("unit", "fahrenheit") + return f"The weather in {location} is 72Β°{'F' if unit == 'fahrenheit' else 'C'} and sunny." + + elif tool_name == "calculate": + expression = args.get("expression", "") + try: + # Clean the expression and safely evaluate it + cleaned_expression = expression.replace("x", "*").replace("Γ—", "*") + result = safe_eval_arithmetic(cleaned_expression) + return f"The result of {expression} is {result}" + except (ValueError, SyntaxError, ZeroDivisionError) as e: + return f"Could not calculate {expression}: {e}" + + elif tool_name == "search_web": + query = args.get("query", "") + return f"Here are the search results for '{query}': [Mock search results]" + + return f"Tool {tool_name} executed with args: {args}" + + +def validate_response_structure(response: Any, expected_fields: List[str]) -> bool: + """Validate that a response has the expected structure""" + if not hasattr(response, "__dict__") and not isinstance(response, dict): + return False + + response_dict = response.__dict__ if hasattr(response, "__dict__") else response + + for field in expected_fields: + if field not in response_dict: + return False + + return True + + +def extract_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from various response formats""" + tool_calls = [] + + # Handle OpenAI format: response.choices[0].message.tool_calls + if hasattr(response, "choices") and len(response.choices) > 0: + choice = response.choices[0] + if ( + hasattr(choice, "message") + and hasattr(choice.message, "tool_calls") + and choice.message.tool_calls + ): + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function"): + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ), + } + ) + + # Handle direct tool_calls attribute (other formats) + elif hasattr(response, "tool_calls") and response.tool_calls: + for tool_call in response.tool_calls: + if hasattr(tool_call, "function"): + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ), + } + ) + + # Handle Anthropic format: response.content with tool_use blocks + elif hasattr(response, "content") and isinstance(response.content, list): + for content in response.content: + if hasattr(content, "type") and content.type == "tool_use": + tool_calls.append({"name": content.name, "arguments": content.input}) + + return tool_calls + + +def assert_valid_chat_response(response: Any, min_length: int = 1): + """Assert that a chat response is valid""" + assert response is not None, "Response should not be None" + + # Extract content from various response formats + content = "" + if hasattr(response, "text"): # Google GenAI + content = response.text + elif hasattr(response, "content"): # Anthropic + if isinstance(response.content, str): + content = response.content + elif isinstance(response.content, list) and len(response.content) > 0: + # Handle list content (like Anthropic) + text_content = [ + c for c in response.content if hasattr(c, "type") and c.type == "text" + ] + if text_content: + content = text_content[0].text + elif hasattr(response, "choices") and len(response.choices) > 0: # OpenAI + # Handle OpenAI format + choice = response.choices[0] + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = choice.message.content or "" + + assert ( + len(content) >= min_length + ), f"Response content should be at least {min_length} characters, got: {content}" + + +def assert_has_tool_calls(response: Any, expected_count: Optional[int] = None): + """Assert that a response contains tool calls""" + tool_calls = extract_tool_calls(response) + + assert len(tool_calls) > 0, "Response should contain tool calls" + + if expected_count is not None: + assert ( + len(tool_calls) == expected_count + ), f"Expected {expected_count} tool calls, got {len(tool_calls)}" + + # Validate tool call structure + for tool_call in tool_calls: + assert "name" in tool_call, "Tool call should have a name" + assert "arguments" in tool_call, "Tool call should have arguments" + + +def assert_valid_image_response(response: Any): + """Assert that an image analysis response is valid""" + assert_valid_chat_response(response, min_length=10) + + # Extract content for image-specific validation + content = "" + if hasattr(response, "text"): # Google GenAI + content = response.text.lower() + elif hasattr(response, "content"): # Anthropic + if isinstance(response.content, str): + content = response.content.lower() + elif isinstance(response.content, list): + text_content = [ + c for c in response.content if hasattr(c, "type") and c.type == "text" + ] + if text_content: + content = text_content[0].text.lower() + elif hasattr(response, "choices") and len(response.choices) > 0: # OpenAI + choice = response.choices[0] + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = (choice.message.content or "").lower() + + # Check for image-related keywords + image_keywords = [ + "image", + "picture", + "photo", + "see", + "visual", + "show", + "appear", + "color", + "scene", + ] + has_image_reference = any(keyword in content for keyword in image_keywords) + + assert ( + has_image_reference + ), f"Response should reference the image content. Got: {content}" + + +def assert_valid_error_response( + response_or_exception: Any, expected_invalid_role: str = "tester" +): + """ + Assert that an error response or exception properly indicates an invalid role error. + + Args: + response_or_exception: Either an HTTP error response or a raised exception + expected_invalid_role: The invalid role that should be mentioned in the error + """ + error_message = "" + error_type = "" + status_code = None + + # Handle different error response formats + if hasattr(response_or_exception, "response"): + # This is likely a requests.HTTPError or similar + try: + error_data = response_or_exception.response.json() + status_code = response_or_exception.response.status_code + + # Extract error message from various formats + if isinstance(error_data, dict): + if "error" in error_data: + if isinstance(error_data["error"], dict): + error_message = error_data["error"].get( + "message", str(error_data["error"]) + ) + error_type = error_data["error"].get("type", "") + else: + error_message = str(error_data["error"]) + else: + error_message = error_data.get("message", str(error_data)) + else: + error_message = str(error_data) + except: + error_message = str(response_or_exception) + + elif hasattr(response_or_exception, "message"): + # Direct error object + error_message = response_or_exception.message + + elif hasattr(response_or_exception, "args") and response_or_exception.args: + # Exception with args + error_message = str(response_or_exception.args[0]) + + else: + # Fallback to string representation + error_message = str(response_or_exception) + + # Convert to lowercase for case-insensitive matching + error_message_lower = error_message.lower() + error_type_lower = error_type.lower() + + # Validate that error message indicates role-related issue + role_error_indicators = [ + expected_invalid_role.lower(), + "role", + "invalid", + "unsupported", + "unknown", + "not allowed", + "not supported", + "bad request", + "invalid_request", + ] + + has_role_error = any( + indicator in error_message_lower or indicator in error_type_lower + for indicator in role_error_indicators + ) + + assert has_role_error, ( + f"Error message should indicate invalid role '{expected_invalid_role}'. " + f"Got error message: '{error_message}', error type: '{error_type}'" + ) + + # Validate status code if available (should be 4xx for client errors) + if status_code is not None: + assert ( + 400 <= status_code < 500 + ), f"Expected 4xx status code for invalid role error, got {status_code}" + + return True + + +def assert_error_propagation(error_response: Any, integration: str): + """ + Assert that error is properly propagated through Bifrost to the integration. + + Args: + error_response: The error response from the integration + integration: The integration name (openai, anthropic, etc.) + """ + # Check that we got an error response (not a success) + assert error_response is not None, "Should have received an error response" + + # Integration-specific error format validation + if integration.lower() == "openai": + # OpenAI format: should have top-level 'type', 'event_id' and 'error' field with nested structure + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "error" in error_data, "OpenAI error should have 'error' field" + assert ( + "type" in error_data + ), "OpenAI error should have top-level 'type' field" + assert ( + "event_id" in error_data + ), "OpenAI error should have top-level 'event_id' field" + assert isinstance( + error_data["type"], str + ), "OpenAI error type should be a string" + assert isinstance( + error_data["event_id"], str + ), "OpenAI error event_id should be a string" + + # Check nested error structure + error_obj = error_data["error"] + assert ( + "message" in error_obj + ), "OpenAI error.error should have 'message' field" + assert "type" in error_obj, "OpenAI error.error should have 'type' field" + assert "code" in error_obj, "OpenAI error.error should have 'code' field" + assert ( + "event_id" in error_obj + ), "OpenAI error.error should have 'event_id' field" + + elif integration.lower() == "anthropic": + # Anthropic format: should have 'type' and 'error' with 'type' and 'message' + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "type" in error_data, "Anthropic error should have 'type' field" + # Type field can be empty string if not set in original error + assert isinstance( + error_data["type"], str + ), "Anthropic error type should be a string" + assert "error" in error_data, "Anthropic error should have 'error' field" + assert ( + "type" in error_data["error"] + ), "Anthropic error.error should have 'type' field" + assert ( + "message" in error_data["error"] + ), "Anthropic error.error should have 'message' field" + + elif integration.lower() in ["google", "gemini", "genai"]: + # Gemini format: follows Google API design guidelines with error.code, error.message, error.status + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "error" in error_data, "Gemini error should have 'error' field" + + # Check Google API standard error structure + error_obj = error_data["error"] + assert ( + "code" in error_obj + ), "Gemini error.error should have 'code' field (HTTP status code)" + assert isinstance( + error_obj["code"], int + ), "Gemini error.error.code should be an integer" + assert ( + "message" in error_obj + ), "Gemini error.error should have 'message' field" + assert isinstance( + error_obj["message"], str + ), "Gemini error.error.message should be a string" + assert ( + "status" in error_obj + ), "Gemini error.error should have 'status' field" + assert isinstance( + error_obj["status"], str + ), "Gemini error.error.status should be a string" + + return True + + +def assert_valid_streaming_response( + chunk: Any, integration: str, is_final: bool = False +): + """ + Assert that a streaming response chunk is valid for the given integration. + + Args: + chunk: Individual streaming response chunk + integration: The integration name (openai, anthropic, etc.) + is_final: Whether this is expected to be the final chunk + """ + assert chunk is not None, "Streaming chunk should not be None" + + if integration.lower() == "openai": + # OpenAI streaming format + assert hasattr(chunk, "choices"), "OpenAI streaming chunk should have choices" + assert ( + len(chunk.choices) > 0 + ), "OpenAI streaming chunk should have at least one choice" + + choice = chunk.choices[0] + assert hasattr(choice, "delta"), "OpenAI streaming choice should have delta" + + # Check for content or tool calls in delta + has_content = ( + hasattr(choice.delta, "content") and choice.delta.content is not None + ) + has_tool_calls = ( + hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls is not None + ) + has_role = hasattr(choice.delta, "role") and choice.delta.role is not None + + # Allow empty deltas for final chunks (they just signal completion) + if not is_final: + assert ( + has_content or has_tool_calls or has_role + ), "OpenAI delta should have content, tool_calls, or role (except for final chunks)" + + if is_final: + assert hasattr( + choice, "finish_reason" + ), "Final chunk should have finish_reason" + assert ( + choice.finish_reason is not None + ), "Final chunk finish_reason should not be None" + + elif integration.lower() == "anthropic": + # Anthropic streaming format + assert hasattr(chunk, "type"), "Anthropic streaming chunk should have type" + + if chunk.type == "content_block_delta": + assert hasattr( + chunk, "delta" + ), "Content block delta should have delta field" + + # Validate based on delta type + if hasattr(chunk.delta, "type"): + if chunk.delta.type == "text_delta": + assert hasattr( + chunk.delta, "text" + ), "Text delta should have text field" + elif chunk.delta.type == "thinking_delta": + assert hasattr( + chunk.delta, "thinking" + ), "Thinking delta should have thinking field" + elif chunk.delta.type == "input_json_delta": + assert hasattr( + chunk.delta, "partial_json" + ), "Input JSON delta should have partial_json field" + else: + # Fallback: if no type specified, assume text_delta for backward compatibility + assert hasattr( + chunk.delta, "text" + ), "Content delta should have text field" + elif chunk.type == "message_delta" and is_final: + assert hasattr(chunk, "usage"), "Final message delta should have usage" + + elif integration.lower() in ["google", "gemini", "genai"]: + # Google streaming format + assert hasattr( + chunk, "candidates" + ), "Google streaming chunk should have candidates" + assert ( + len(chunk.candidates) > 0 + ), "Google streaming chunk should have at least one candidate" + + candidate = chunk.candidates[0] + assert hasattr(candidate, "content"), "Google candidate should have content" + + if is_final: + assert hasattr( + candidate, "finish_reason" + ), "Final chunk should have finish_reason" + + +def collect_streaming_content( + stream, integration: str, timeout: int = 30 +) -> tuple[str, int, bool]: + """ + Collect content from a streaming response and validate the stream. + + Args: + stream: The streaming response iterator + integration: The integration name (openai, anthropic, etc.) + timeout: Maximum time to wait for stream completion + + Returns: + tuple: (collected_content, chunk_count, tool_calls_detected) + """ + import time + + content_parts = [] + chunk_count = 0 + tool_calls_detected = False + start_time = time.time() + + for chunk in stream: + chunk_count += 1 + + # Check timeout + if time.time() - start_time > timeout: + raise TimeoutError(f"Streaming took longer than {timeout} seconds") + + # Validate chunk + is_final = False + if integration.lower() == "openai": + is_final = ( + hasattr(chunk, "choices") + and len(chunk.choices) > 0 + and hasattr(chunk.choices[0], "finish_reason") + and chunk.choices[0].finish_reason is not None + ) + + assert_valid_streaming_response(chunk, integration, is_final) + + # Extract content based on integration + if integration.lower() == "openai": + choice = chunk.choices[0] + if hasattr(choice.delta, "content") and choice.delta.content: + content_parts.append(choice.delta.content) + if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls: + tool_calls_detected = True + + elif integration.lower() == "anthropic": + if chunk.type == "content_block_delta": + if hasattr(chunk.delta, "text") and chunk.delta.text: + content_parts.append(chunk.delta.text) + elif hasattr(chunk.delta, "thinking") and chunk.delta.thinking: + content_parts.append(chunk.delta.thinking) + # Note: partial_json from input_json_delta is not user-visible content + elif chunk.type == "content_block_start": + # Check for tool use content blocks + if ( + hasattr(chunk, "content_block") + and hasattr(chunk.content_block, "type") + and chunk.content_block.type == "tool_use" + ): + tool_calls_detected = True + + elif integration.lower() in ["google", "gemini", "genai"]: + if hasattr(chunk, "candidates") and len(chunk.candidates) > 0: + candidate = chunk.candidates[0] + if ( + hasattr(candidate.content, "parts") + and len(candidate.content.parts) > 0 + ): + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + content_parts.append(part.text) + + # Safety check + if chunk_count > 500: + raise ValueError( + "Received too many streaming chunks, something might be wrong" + ) + + content = "".join(content_parts) + return content, chunk_count, tool_calls_detected + + +# Test Categories +class TestCategories: + """Constants for test categories""" + + SIMPLE_CHAT = "simple_chat" + MULTI_TURN = "multi_turn" + SINGLE_TOOL = "single_tool" + MULTIPLE_TOOLS = "multiple_tools" + E2E_TOOLS = "e2e_tools" + AUTO_FUNCTION = "auto_function" + IMAGE_URL = "image_url" + IMAGE_BASE64 = "image_base64" + STREAMING = "streaming" + MULTIPLE_IMAGES = "multiple_images" + COMPLEX_E2E = "complex_e2e" + INTEGRATION_SPECIFIC = "integration_specific" + ERROR_HANDLING = "error_handling" + + +# Speech and Transcription Test Data +SPEECH_TEST_INPUT = "Hello, this is a test of the speech synthesis functionality. The quick brown fox jumps over the lazy dog." + +SPEECH_TEST_VOICES = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + + +# Generate a simple test audio file (sine wave) for transcription testing +def generate_test_audio() -> bytes: + """Generate a simple sine wave audio file for testing transcription""" + import wave + import math + import struct + + # Audio parameters + sample_rate = 16000 # 16kHz sample rate + duration = 2 # 2 seconds + frequency = 440 # A4 note (440 Hz) + + # Generate sine wave samples + samples = [] + for i in range(int(sample_rate * duration)): + t = i / sample_rate + sample = int(32767 * math.sin(2 * math.pi * frequency * t)) + samples.append(struct.pack("= expected_audio_size_min + ), f"Audio data should be at least {expected_audio_size_min} bytes, got {len(audio_data)}" + + # Check for common audio file headers + # MP3 files start with 0xFF followed by 0xFB, 0xF3, 0xF2, or 0xF0 (MPEG frame sync) + # or with an ID3 tag + is_mp3 = ( + audio_data.startswith(b"\xff\xfb") # MPEG-1 Layer III + or audio_data.startswith(b"\xff\xf3") # MPEG-2 Layer III + or audio_data.startswith(b"\xff\xf2") # MPEG-2.5 Layer III + or audio_data.startswith(b"\xff\xf0") # MPEG-2 Layer I/II + or audio_data.startswith(b"ID3") # ID3 tag + ) + is_wav = audio_data.startswith(b"RIFF") and b"WAVE" in audio_data[:20] + is_opus = audio_data.startswith(b"OggS") + is_aac = audio_data.startswith(b"\xff\xf1") or audio_data.startswith(b"\xff\xf9") + is_flac = audio_data.startswith(b"fLaC") + + assert ( + is_mp3 or is_wav or is_opus or is_aac or is_flac + ), f"Audio data should be in a recognized format (MP3, WAV, Opus, AAC, or FLAC) but got {audio_data[:100]}" + + +def assert_valid_transcription_response(response: Any, min_text_length: int = 1): + """Assert that a transcription response is valid""" + assert response is not None, "Transcription response should not be None" + + # Extract transcribed text from various response formats + text_content = "" + + if hasattr(response, "text"): + # Direct text attribute + text_content = response.text + elif hasattr(response, "content"): + # JSON response with content + if isinstance(response.content, str): + text_content = response.content + elif isinstance(response.content, dict) and "text" in response.content: + text_content = response.content["text"] + elif isinstance(response, dict): + # Direct dictionary response + text_content = response.get("text", "") + elif isinstance(response, str): + # Direct string response + text_content = response + + assert text_content is not None, "Transcription response should contain text" + assert isinstance( + text_content, str + ), f"Transcribed text should be string, got {type(text_content)}" + assert ( + len(text_content.strip()) >= min_text_length + ), f"Transcribed text should be at least {min_text_length} characters, got: '{text_content}'" + + +def assert_valid_embedding_response( + response: Any, expected_dimensions: Optional[int] = None +) -> None: + """Assert that an embedding response is valid""" + assert response is not None, "Embedding response should not be None" + + # Check if it's an OpenAI-style response object + if hasattr(response, "data"): + assert ( + len(response.data) > 0 + ), "Embedding response should contain at least one embedding" + + embedding = response.data[0].embedding + assert isinstance( + embedding, list + ), f"Embedding should be a list, got {type(embedding)}" + assert len(embedding) > 0, "Embedding should not be empty" + assert all( + isinstance(x, (int, float)) for x in embedding + ), "All embedding values should be numeric" + + if expected_dimensions: + assert ( + len(embedding) == expected_dimensions + ), f"Expected {expected_dimensions} dimensions, got {len(embedding)}" + + # Check if usage information is present + if hasattr(response, "usage") and response.usage: + assert hasattr( + response.usage, "total_tokens" + ), "Usage should include total_tokens" + assert ( + response.usage.total_tokens > 0 + ), "Token usage should be greater than 0" + + elif hasattr(response, "embeddings"): + assert len(response.embeddings) > 0, "Embedding should not be empty" + embedding = response.embeddings[0].values + assert isinstance(embedding, list), "Embedding should be a list" + assert len(embedding) > 0, "Embedding should not be empty" + assert all( + isinstance(x, (int, float)) for x in embedding + ), "All embedding values should be numeric" + if expected_dimensions: + assert ( + len(embedding) == expected_dimensions + ), f"Expected {expected_dimensions} dimensions, got {len(embedding)}" + + # Check if it's a direct list (embedding vector) + elif isinstance(response, list): + assert len(response) > 0, "Embedding should not be empty" + assert all( + isinstance(x, (int, float)) for x in response + ), "All embedding values should be numeric" + + if expected_dimensions: + assert ( + len(response) == expected_dimensions + ), f"Expected {expected_dimensions} dimensions, got {len(response)}" + + else: + raise AssertionError(f"Invalid embedding response format: {type(response)}") + + +def assert_valid_embeddings_batch_response( + response: Any, expected_count: int, expected_dimensions: Optional[int] = None +) -> None: + """Assert that a batch embeddings response is valid""" + assert response is not None, "Embeddings batch response should not be None" + + # Check if it's an OpenAI-style response object + if hasattr(response, "data"): + assert ( + len(response.data) == expected_count + ), f"Expected {expected_count} embeddings, got {len(response.data)}" + + for i, embedding_obj in enumerate(response.data): + assert hasattr( + embedding_obj, "embedding" + ), f"Embedding object {i} should have 'embedding' attribute" + embedding = embedding_obj.embedding + + assert isinstance( + embedding, list + ), f"Embedding {i} should be a list, got {type(embedding)}" + assert len(embedding) > 0, f"Embedding {i} should not be empty" + assert all( + isinstance(x, (int, float)) for x in embedding + ), f"All values in embedding {i} should be numeric" + + if expected_dimensions: + assert ( + len(embedding) == expected_dimensions + ), f"Embedding {i}: expected {expected_dimensions} dimensions, got {len(embedding)}" + + # Check usage information + if hasattr(response, "usage") and response.usage: + assert hasattr( + response.usage, "total_tokens" + ), "Usage should include total_tokens" + assert ( + response.usage.total_tokens > 0 + ), "Token usage should be greater than 0" + + # Check if it's a direct list of embeddings + elif isinstance(response, list): + assert ( + len(response) == expected_count + ), f"Expected {expected_count} embeddings, got {len(response)}" + + for i, embedding in enumerate(response): + assert isinstance( + embedding, list + ), f"Embedding {i} should be a list, got {type(embedding)}" + assert len(embedding) > 0, f"Embedding {i} should not be empty" + assert all( + isinstance(x, (int, float)) for x in embedding + ), f"All values in embedding {i} should be numeric" + + if expected_dimensions: + assert ( + len(embedding) == expected_dimensions + ), f"Embedding {i}: expected {expected_dimensions} dimensions, got {len(embedding)}" + + else: + raise AssertionError( + f"Invalid embeddings batch response format: {type(response)}" + ) + + +def calculate_cosine_similarity( + embedding1: List[float], embedding2: List[float] +) -> float: + """Calculate cosine similarity between two embedding vectors""" + import math + + assert len(embedding1) == len(embedding2), "Embeddings must have the same dimension" + + # Calculate dot product + dot_product = sum(a * b for a, b in zip(embedding1, embedding2)) + + # Calculate magnitudes + magnitude1 = math.sqrt(sum(a * a for a in embedding1)) + magnitude2 = math.sqrt(sum(b * b for b in embedding2)) + + # Avoid division by zero + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + + return dot_product / (magnitude1 * magnitude2) + + +def assert_embeddings_similarity( + embedding1: List[float], + embedding2: List[float], + min_similarity: float = 0.8, + max_similarity: float = 1.0, +) -> None: + """Assert that two embeddings have expected similarity""" + similarity = calculate_cosine_similarity(embedding1, embedding2) + assert ( + min_similarity <= similarity <= max_similarity + ), f"Embedding similarity {similarity:.4f} should be between {min_similarity} and {max_similarity}" + + +def assert_embeddings_dissimilarity( + embedding1: List[float], embedding2: List[float], max_similarity: float = 0.5 +) -> None: + """Assert that two embeddings are sufficiently different""" + similarity = calculate_cosine_similarity(embedding1, embedding2) + assert ( + similarity <= max_similarity + ), f"Embedding similarity {similarity:.4f} should be at most {max_similarity} for dissimilar texts" + + +def assert_valid_streaming_speech_response(chunk: Any, integration: str): + """Assert that a streaming speech response chunk is valid""" + assert chunk is not None, "Streaming speech chunk should not be None" + + if integration.lower() == "openai": + # For OpenAI, speech streaming returns audio chunks + # The chunk might be direct bytes or wrapped in an object + if hasattr(chunk, "audio"): + audio_data = chunk.audio + elif hasattr(chunk, "data"): + audio_data = chunk.data + elif isinstance(chunk, bytes): + audio_data = chunk + else: + # Try to find audio data in the chunk + audio_data = None + for attr in ["content", "chunk", "audio_chunk"]: + if hasattr(chunk, attr): + audio_data = getattr(chunk, attr) + break + + if audio_data: + assert isinstance( + audio_data, bytes + ), f"Audio chunk should be bytes, got {type(audio_data)}" + assert len(audio_data) > 0, "Audio chunk should not be empty" + + +def assert_valid_streaming_transcription_response(chunk: Any, integration: str): + """Assert that a streaming transcription response chunk is valid""" + assert chunk is not None, "Streaming transcription chunk should not be None" + + if integration.lower() == "openai": + # For OpenAI, transcription streaming returns text chunks + if hasattr(chunk, "text"): + text_chunk = chunk.text + elif hasattr(chunk, "content"): + text_chunk = chunk.content + elif isinstance(chunk, str): + text_chunk = chunk + elif isinstance(chunk, dict) and "text" in chunk: + text_chunk = chunk["text"] + else: + # Try to find text data in the chunk + text_chunk = None + for attr in ["data", "chunk", "text_chunk"]: + if hasattr(chunk, attr): + text_chunk = getattr(chunk, attr) + break + + if text_chunk: + assert isinstance( + text_chunk, str + ), f"Text chunk should be string, got {type(text_chunk)}" + # Note: text chunks can be empty in streaming (e.g., just punctuation updates) + + +def collect_streaming_speech_content( + stream, integration: str, timeout: int = 60 +) -> tuple[bytes, int]: + """ + Collect audio content from a streaming speech response. + + Args: + stream: The streaming response iterator + integration: The integration name (openai, etc.) + timeout: Maximum time to wait for stream completion + + Returns: + tuple: (collected_audio_bytes, chunk_count) + """ + import time + + audio_chunks = [] + chunk_count = 0 + start_time = time.time() + + for chunk in stream: + chunk_count += 1 + + # Check timeout + if time.time() - start_time > timeout: + raise TimeoutError(f"Speech streaming took longer than {timeout} seconds") + + # Validate chunk + assert_valid_streaming_speech_response(chunk, integration) + + # Extract audio data + if integration.lower() == "openai": + if hasattr(chunk, "audio") and chunk.audio: + audio_chunks.append(chunk.audio) + elif hasattr(chunk, "data") and chunk.data: + audio_chunks.append(chunk.data) + elif isinstance(chunk, bytes): + audio_chunks.append(chunk) + + # Safety check + if chunk_count > 1000: + raise ValueError( + "Received too many speech streaming chunks, something might be wrong" + ) + + # Combine all audio chunks + complete_audio = b"".join(audio_chunks) + return complete_audio, chunk_count + + +def collect_streaming_transcription_content( + stream, integration: str, timeout: int = 60 +) -> tuple[str, int]: + """ + Collect text content from a streaming transcription response. + + Args: + stream: The streaming response iterator + integration: The integration name (openai, etc.) + timeout: Maximum time to wait for stream completion + + Returns: + tuple: (collected_text, chunk_count) + """ + import time + + text_chunks = [] + chunk_count = 0 + start_time = time.time() + + for chunk in stream: + chunk_count += 1 + + # Check timeout + if time.time() - start_time > timeout: + raise TimeoutError( + f"Transcription streaming took longer than {timeout} seconds" + ) + + # Validate chunk + assert_valid_streaming_transcription_response(chunk, integration) + + # Extract text data + if integration.lower() == "openai": + if hasattr(chunk, "text") and chunk.text: + text_chunks.append(chunk.text) + elif hasattr(chunk, "content") and chunk.content: + text_chunks.append(chunk.content) + elif isinstance(chunk, str): + text_chunks.append(chunk) + + # Safety check + if chunk_count > 1000: + raise ValueError( + "Received too many transcription streaming chunks, something might be wrong" + ) + + # Combine all text chunks + complete_text = "".join(text_chunks) + return complete_text, chunk_count + + +# Environment helpers +def get_api_key(integration: str) -> str: + """Get API key for a integration from environment variables""" + key_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "litellm": "LITELLM_API_KEY", + } + + env_var = key_map.get(integration.lower()) + if not env_var: + raise ValueError(f"Unknown integration: {integration}") + + api_key = os.getenv(env_var) + if not api_key: + raise ValueError(f"Missing environment variable: {env_var}") + + return api_key + + +def skip_if_no_api_key(integration: str): + """Decorator to skip tests if API key is not available""" + import pytest + + def decorator(func): + try: + get_api_key(integration) + return func + except ValueError: + return pytest.mark.skip(f"No API key available for {integration}")(func) + + return decorator diff --git a/tests/integrations/tests/utils/config_loader.py b/tests/integrations/tests/utils/config_loader.py new file mode 100644 index 000000000..ae683d6b0 --- /dev/null +++ b/tests/integrations/tests/utils/config_loader.py @@ -0,0 +1,299 @@ +""" +Configuration loader for Bifrost integration tests. + +This module loads configuration from config.yml and provides utilities +for constructing integration URLs through the Bifrost gateway. +""" + +import os +import yaml +from typing import Dict, Any, Optional +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class BifrostConfig: + """Bifrost gateway configuration""" + + base_url: str + endpoints: Dict[str, str] + + +@dataclass +class IntegrationModels: + """Model configuration for a integration""" + + chat: str + vision: str + tools: str + alternatives: list + + +@dataclass +class TestConfig: + """Complete test configuration""" + + bifrost: BifrostConfig + api: Dict[str, Any] + models: Dict[str, IntegrationModels] + model_capabilities: Dict[str, Dict[str, Any]] + test_settings: Dict[str, Any] + integration_settings: Dict[str, Any] + environments: Dict[str, Any] + logging: Dict[str, Any] + + +class ConfigLoader: + """Configuration loader for Bifrost integration tests""" + + def __init__(self, config_path: Optional[str] = None): + """Initialize configuration loader + + Args: + config_path: Path to config.yml file. If None, looks for config.yml in project root. + """ + if config_path is None: + # Look for config.yml in project root + project_root = Path(__file__).parent.parent.parent + config_path = project_root / "config.yml" + + self.config_path = Path(config_path) + self._config = None + self._load_config() + + def _load_config(self): + """Load configuration from YAML file""" + if not self.config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {self.config_path}") + + with open(self.config_path, "r") as f: + raw_config = yaml.safe_load(f) + + # Expand environment variables + self._config = self._expand_env_vars(raw_config) + + def _expand_env_vars(self, obj): + """Recursively expand environment variables in configuration""" + if isinstance(obj, dict): + return {k: self._expand_env_vars(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._expand_env_vars(item) for item in obj] + elif isinstance(obj, str): + # Handle ${VAR:-default} syntax + import re + + pattern = r"\$\{([^}]+)\}" + + def replace_var(match): + var_expr = match.group(1) + if ":-" in var_expr: + var_name, default_value = var_expr.split(":-", 1) + return os.getenv(var_name, default_value) + else: + return os.getenv(var_expr, "") + + return re.sub(pattern, replace_var, obj) + else: + return obj + + def get_integration_url(self, integration: str) -> str: + """Get the complete URL for a integration + + Args: + integration: Integration name (openai, anthropic, google, litellm) + + Returns: + Complete URL for the integration + + Examples: + get_integration_url("openai") -> "http://localhost:8080/openai" + """ + bifrost_config = self._config["bifrost"] + base_url = bifrost_config["base_url"] + endpoint = bifrost_config["endpoints"].get(integration, "") + + if not endpoint: + raise ValueError(f"No endpoint configured for integration: {integration}") + + return f"{base_url.rstrip('/')}/{endpoint}" + + def get_bifrost_config(self) -> BifrostConfig: + """Get Bifrost configuration""" + bifrost_data = self._config["bifrost"] + return BifrostConfig( + base_url=bifrost_data["base_url"], endpoints=bifrost_data["endpoints"] + ) + + def get_model(self, integration: str, model_type: str = "chat") -> str: + """Get model name for a integration and type""" + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + + integration_models = self._config["models"][integration] + + if model_type not in integration_models: + raise ValueError( + f"Unknown model type '{model_type}' for integration '{integration}'" + ) + + return integration_models[model_type] + + def get_model_alternatives(self, integration: str) -> list: + """Get alternative models for a integration""" + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + + return self._config["models"][integration].get("alternatives", []) + + def get_model_capabilities(self, model: str) -> Dict[str, Any]: + """Get capabilities for a specific model""" + return self._config["model_capabilities"].get( + model, + { + "chat": True, + "tools": False, + "vision": False, + "max_tokens": 4096, + "context_window": 4096, + }, + ) + + def supports_capability(self, model: str, capability: str) -> bool: + """Check if a model supports a specific capability""" + caps = self.get_model_capabilities(model) + return caps.get(capability, False) + + def get_api_config(self) -> Dict[str, Any]: + """Get API configuration (timeout, retries, etc.)""" + return self._config["api"] + + def get_test_settings(self) -> Dict[str, Any]: + """Get test configuration settings""" + return self._config["test_settings"] + + def get_integration_settings(self, integration: str) -> Dict[str, Any]: + """Get integration-specific settings""" + return self._config["integration_settings"].get(integration, {}) + + def get_environment_config(self, environment: str = None) -> Dict[str, Any]: + """Get environment-specific configuration + + Args: + environment: Environment name (development, production, etc.) + If None, uses TEST_ENV environment variable or 'development' + """ + if environment is None: + environment = os.getenv("TEST_ENV", "development") + + return self._config["environments"].get(environment, {}) + + def get_logging_config(self) -> Dict[str, Any]: + """Get logging configuration""" + return self._config["logging"] + + def list_integrations(self) -> list: + """List all configured integrations""" + return list(self._config["bifrost"]["endpoints"].keys()) + + def list_models(self, integration: str = None) -> Dict[str, Any]: + """List all models for a integration or all integrations""" + if integration: + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + return {integration: self._config["models"][integration]} + + return self._config["models"] + + def validate_config(self) -> bool: + """Validate configuration completeness""" + required_sections = ["bifrost", "models", "api", "test_settings"] + + for section in required_sections: + if section not in self._config: + raise ValueError(f"Missing required configuration section: {section}") + + # Validate Bifrost configuration + bifrost = self._config["bifrost"] + if "base_url" not in bifrost or "endpoints" not in bifrost: + raise ValueError("Bifrost configuration missing base_url or endpoints") + + # Validate that all integrations have model configurations + integrations = list(bifrost["endpoints"].keys()) + for integration in integrations: + if integration not in self._config["models"]: + raise ValueError( + f"No model configuration for integration: {integration}" + ) + + return True + + def print_config_summary(self): + """Print a summary of the configuration""" + print("πŸ”§ BIFROST INTEGRATION TEST CONFIGURATION") + print("=" * 80) + + # Bifrost configuration + bifrost = self.get_bifrost_config() + print(f"\nπŸŒ‰ BIFROST GATEWAY:") + print(f" Base URL: {bifrost.base_url}") + print(f" Endpoints:") + for integration, endpoint in bifrost.endpoints.items(): + full_url = f"{bifrost.base_url.rstrip('/')}/{endpoint}" + print(f" {integration}: {full_url}") + + # Model configurations + print(f"\nπŸ€– MODEL CONFIGURATIONS:") + for integration, models in self._config["models"].items(): + print(f" {integration.upper()}:") + print(f" Chat: {models['chat']}") + print(f" Vision: {models['vision']}") + print(f" Tools: {models['tools']}") + print(f" Alternatives: {len(models['alternatives'])} models") + + # API settings + api_config = self.get_api_config() + print(f"\nβš™οΈ API SETTINGS:") + print(f" Timeout: {api_config['timeout']}s") + print(f" Max Retries: {api_config['max_retries']}") + print(f" Retry Delay: {api_config['retry_delay']}s") + + print(f"\nβœ… Configuration loaded successfully from: {self.config_path}") + + +# Global configuration instance +_config_loader = None + + +def get_config() -> ConfigLoader: + """Get global configuration instance""" + global _config_loader + if _config_loader is None: + _config_loader = ConfigLoader() + return _config_loader + + +def get_integration_url(integration: str) -> str: + return get_config().get_integration_url(integration) + + +def get_model(integration: str, model_type: str = "chat") -> str: + """Convenience function to get model name""" + return get_config().get_model(integration, model_type) + + +def get_model_capabilities(model: str) -> Dict[str, Any]: + """Convenience function to get model capabilities""" + return get_config().get_model_capabilities(model) + + +def supports_capability(model: str, capability: str) -> bool: + """Convenience function to check model capability""" + return get_config().supports_capability(model, capability) + + +if __name__ == "__main__": + # Print configuration summary when run directly + config = get_config() + config.validate_config() + config.print_config_summary() diff --git a/tests/integrations/tests/utils/models.py b/tests/integrations/tests/utils/models.py new file mode 100644 index 000000000..315e5410c --- /dev/null +++ b/tests/integrations/tests/utils/models.py @@ -0,0 +1,66 @@ +""" +Model configurations for each integration. + +This file now acts as a compatibility layer and convenience wrapper +around the new configuration system in config.yml and config_loader.py. + +All model data is now centralized in config.yml for easier maintenance. +""" + +from typing import Dict, List +from dataclasses import dataclass +from .config_loader import get_config + + +@dataclass +class IntegrationModels: + """Model configuration for a integration""" + + chat: str # Primary chat model + vision: str # Vision/multimodal model + tools: str # Function calling model + alternatives: List[str] # Alternative models for testing + + +def get_integration_models() -> Dict[str, IntegrationModels]: + """Get all integration model configurations from config.yml""" + config = get_config() + integration_models = {} + + for integration in config.list_integrations(): + models_config = config.list_models(integration) + integration_models[integration] = IntegrationModels( + chat=models_config["chat"], + vision=models_config["vision"], + tools=models_config["tools"], + alternatives=models_config["alternatives"], + ) + + return integration_models + + +# Backward compatibility - load from config +INTEGRATION_MODELS = get_integration_models() + + +def get_alternatives(integration: str) -> List[str]: + """Get alternative models for a integration""" + config = get_config() + return config.get_model_alternatives(integration) + + +def list_all_models() -> Dict[str, Dict[str, str]]: + """List all models by integration and type""" + config = get_config() + return config.list_models() + + +# Print model summary for documentation +def print_model_summary(): + """Print a summary of all models and their capabilities""" + config = get_config() + config.print_config_summary() + + +if __name__ == "__main__": + print_model_summary() diff --git a/transports/.env.sample b/transports/.env.sample deleted file mode 100644 index 30e582a35..000000000 --- a/transports/.env.sample +++ /dev/null @@ -1,10 +0,0 @@ -OPENAI_API_KEY = YOUR_OPENAI_API_KEY -ANTHROPIC_API_KEY = YOUR_ANTHROPIC_API_KEY -BEDROCK_API_KEY = YOUR_BEDROCK_API_KEY -BEDROCK_ACCESS_KEY = YOUR_BEDROCK_ACCESS_KEY -COHERE_API_KEY = YOUR_COHERE_API_KEY -AZURE_API_KEY = YOUR_AZURE_API_KEY -AZURE_ENDPOINT = YOUR_AZURE_ENDPOINT - -MAXIM_API_KEY = YOUR_MAXIM_API_KEY -MAXIM_LOGGER_ID = YOUR_MAXIM_LOGGER_ID \ No newline at end of file diff --git a/transports/Dockerfile b/transports/Dockerfile index df9ac9901..4fe1d847e 100644 --- a/transports/Dockerfile +++ b/transports/Dockerfile @@ -1,61 +1,93 @@ -# --- First Stage: Builder image --- -FROM golang:1.24 AS builder +# --- UI Build Stage: Build the Next.js frontend --- +FROM node:24-alpine3.22 AS ui-builder WORKDIR /app -# Set environment for static build -ENV CGO_ENABLED=0 -ENV GOOS=linux -ENV GOARCH=amd64 +# Copy UI package files and install dependencies +COPY ui/package*.json ./ +RUN npm ci -# Define build-time variable for transport type -ARG TRANSPORT_TYPE=http +# Copy UI source code +COPY ui/ ./ -# Initialize Go module and fetch the bifrost transport package -RUN go mod init bifrost-transports && \ - go get github.com/maximhq/bifrost/transports/${TRANSPORT_TYPE}@latest +# Build UI (skip the copy-build step) +RUN npx next build +RUN node scripts/fix-paths.js +# Skip the copy-build step since we'll copy the files in the Go build stage -# Build the binary from the fetched package with static linking -RUN go build -ldflags="-w -s" -o /app/main github.com/maximhq/bifrost/transports/${TRANSPORT_TYPE} && \ - test -f /app/main || (echo "Build failed: /app/main not found" && exit 1) && \ - ls -lh /app/main +# --- Go Build Stage: Compile the Go binary --- +FROM golang:1.24-alpine3.22 AS builder +WORKDIR /app + +# Install dependencies including gcc for CGO and sqlite +RUN apk add --no-cache upx gcc musl-dev sqlite-dev + +# Set environment for CGO-enabled build (required for go-sqlite3) +ENV CGO_ENABLED=1 GOOS=linux + +COPY transports/go.mod transports/go.sum ./ +RUN ls +RUN cat go.mod +RUN go mod download + +# Copy source code and dependencies +COPY transports/ ./ + +COPY --from=ui-builder /app/out ./bifrost-http/ui + +# Build the binary with CGO enabled and static SQLite linking +ENV GOWORK=off +ARG VERSION=unknown +RUN go build \ + -ldflags="-w -s -extldflags '-static' -X main.Version=v${VERSION}" \ + -a -trimpath \ + -tags "sqlite_static" \ + -o /app/main \ + ./bifrost-http + +# Compress binary with upx +RUN upx --best --lzma /app/main + +# Verify build succeeded +RUN test -f /app/main || (echo "Build failed" && exit 1) -# --- Second Stage: Runtime image --- -FROM alpine:latest +# --- Runtime Stage: Minimal runtime image --- +FROM alpine:3.22 WORKDIR /app -# Copy the compiled binary from the builder stage +# Create data directory and set up user COPY --from=builder /app/main . -# Ensure the binary is executable -RUN chmod +x /app/main -# Create a directory to store configuration files -RUN mkdir -p /app/config - -# Define build-time variables for config file paths -ARG CONFIG_PATH -ARG ENV_PATH -ARG PORT -ARG POOL_SIZE -ARG DROP_EXCESS_REQUESTS - -# Set default values if args are not provided -ENV APP_PORT=${PORT:-8080} -ENV APP_POOL_SIZE=${POOL_SIZE:-300} -ENV APP_DROP_EXCESS_REQUESTS=${DROP_EXCESS_REQUESTS:-false} - -# Copy the config and environment files into the image -COPY ${CONFIG_PATH} /app/config/config.json -COPY ${ENV_PATH} /app/config/.env - -# Write a small script to validate config presence and run the app -RUN echo '#!/bin/sh' > /app/entrypoint.sh && \ - echo 'if [ ! -f /app/config/config.json ]; then echo "Missing config.json"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'if [ ! -f /app/config/.env ]; then echo "Missing .env"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'if [ ! -f /app/main ]; then echo "Missing main binary"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'exec /app/main -config /app/config/config.json -env /app/config/.env -port "$APP_PORT" -pool-size "$APP_POOL_SIZE" -drop-excess-requests "$APP_DROP_EXCESS_REQUESTS"' >> /app/entrypoint.sh && \ - chmod +x /app/entrypoint.sh - -# Expose the port defined by argument -EXPOSE ${PORT:-8080} - -# Use the script as the entry point -ENTRYPOINT ["/app/entrypoint.sh"] \ No newline at end of file +COPY --from=builder /app/docker-entrypoint.sh . + +# Getting arguments +ARG ARG_APP_PORT=8080 +ARG ARG_APP_HOST=0.0.0.0 +ARG ARG_LOG_LEVEL=info +ARG ARG_LOG_STYLE=json +ARG ARG_APP_DIR=/app/data + +# Environment variables with defaults (can be overridden at runtime) +ENV APP_PORT=$ARG_APP_PORT \ + APP_HOST=$ARG_APP_HOST \ + LOG_LEVEL=$ARG_LOG_LEVEL \ + LOG_STYLE=$ARG_LOG_STYLE \ + APP_DIR=$ARG_APP_DIR + + +RUN mkdir -p $APP_DIR/logs && \ + adduser -D -s /bin/sh appuser && \ + chown -R appuser:appuser /app && \ + chmod +x /app/docker-entrypoint.sh +USER appuser + + +# Declare volume for data persistence +VOLUME ["/app/data"] +EXPOSE $APP_PORT + +# Health check for container status monitoring +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:${APP_PORT}/metrics || exit 1 + +# Use entrypoint script that handles volume permissions and argument processing +ENTRYPOINT ["/app/docker-entrypoint.sh"] +CMD ["/app/main"] \ No newline at end of file diff --git a/transports/README.md b/transports/README.md index 0f0670a03..65418db01 100644 --- a/transports/README.md +++ b/transports/README.md @@ -1,178 +1,166 @@ -# Bifrost Transports +# Bifrost Gateway -This package contains clients for various transports that can be used to spin up your Bifrost client with just a single line of code. +Bifrost Gateway is a blazing-fast HTTP API that unifies access to 12+ AI providers (OpenAI, Anthropic, AWS Bedrock, Google Vertex, and more) through a single OpenAI-compatible interface. Deploy in seconds with zero configuration and get automatic fallbacks, semantic caching, tool calling, and enterprise-grade features. -## πŸ“‘ Table of Contents - -- [Bifrost Transports](#bifrost-transports) - - [πŸ“‘ Table of Contents](#-table-of-contents) - - [πŸš€ Setting Up Transports](#-setting-up-transports) - - [Prerequisites](#prerequisites) - - [Configuration](#configuration) - - [Docker Setup](#docker-setup) - - [Go Setup](#go-setup) - - [🧰 Usage](#-usage) - - [Text Completions](#text-completions) - - [Chat Completions](#chat-completions) - - [πŸ”§ Advanced Features](#-advanced-features) - - [Fallbacks](#fallbacks) +**Complete Documentation**: [https://docs.getbifrost.ai](https://docs.getbifrost.ai) --- -## πŸš€ Setting Up Transports +## Quick Start -### Prerequisites -- Go 1.23 or higher (if not using Docker) -- Access to at least one AI model provider (OpenAI, Anthropic, etc.) -- API keys for the providers you wish to use +### Installation -### Configuration +Choose your preferred method: -Bifrost uses a combination of a JSON configuration file and environment variables: +#### NPX (Recommended) -1. **JSON Configuration File**: Bifrost requires a configuration file to set up the gateway. This includes all your provider-level settings, keys, and meta configs for each of your providers. - -2. **Environment Variables**: If you don't want to include your keys in your config file, you can provide a `.env` file and add a prefix of `env.` followed by its key in your `.env` file. +```bash +# Install and run locally +npx -y @maximhq/bifrost -```json -{ - "keys": [{ - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o-mini", "gpt-4-turbo"], - "weight": 1.0 - }] -} +# Open web interface at http://localhost:8080 ``` -In this example, `OPENAI_API_KEY` refers to a key in the `.env` file. At runtime, its value will be used to replace the placeholder. +#### Docker + +```bash +# Pull and run Bifrost Gateway +docker pull maximhq/bifrost +docker run -p 8080:8080 maximhq/bifrost + +# For persistent configuration +docker run -p 8080:8080 -v $(pwd)/data:/app/data maximhq/bifrost +``` + +### Configuration + +Bifrost starts with zero configuration needed. Configure providers through the **built-in web UI** at `http://localhost:8080` or via API: + +```bash +# Add OpenAI provider via API +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "keys": [{"value": "sk-your-openai-key", "models": ["gpt-4o-mini"], "weight": 1.0}] + }' +``` -The same setup applies to keys in meta configs of all providers: +For file-based configuration, create `config.json` in your app directory: ```json { - "meta_config": { - "secret_access_key": "env.BEDROCK_ACCESS_KEY", - "region": "env.BEDROCK_REGION" + "providers": { + "openai": { + "keys": [{"value": "env.OPENAI_API_KEY", "models": ["gpt-4o-mini"], "weight": 1.0}] + } } } ``` -In this example, `BEDROCK_ACCESS_KEY` and `BEDROCK_REGION` refer to keys in the `.env` file. - -Please refer to `config.example.json` and `.env.sample` for examples. - -### Docker Setup - -You can run Bifrost using our **independent Dockerfile**. Just copy our Dockerfile and run these commands to get your Bifrost instance up and running: +### Your First API Call ```bash -docker build \ - --build-arg CONFIG_PATH=./config.example.json \ - --build-arg ENV_PATH=./.env.sample \ - --build-arg PORT=8080 \ - --build-arg POOL_SIZE=300 \ - -t bifrost-transports . - -docker run -p 8080:8080 bifrost-transports +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello, Bifrost!"}] + }' ``` -You can also add a flag for `DROP_EXCESS_REQUESTS=false` in your Docker build command to drop excess requests when the buffer is full. Read more about `DROP_EXCESS_REQUESTS` and `POOL_SIZE` [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). +**That's it!** You now have a unified AI gateway running locally. --- -### Go Setup +## Key Features -If you wish to run Bifrost in your Go environment, follow these steps: +Bifrost Gateway provides enterprise-grade AI infrastructure with these core capabilities: -1. Install your binary: +### Core Features -```bash -go install github.com/maximhq/bifrost/transports/http@latest -``` +- **[Unified Interface](https://docs.getbifrost.ai/features/unified-interface)** - Single OpenAI-compatible API for all providers +- **[Multi-Provider Support](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration)** - OpenAI, Anthropic, AWS Bedrock, Google Vertex, Azure, Cohere, Mistral, Ollama, Groq, and more +- **[Drop-in Replacement](https://docs.getbifrost.ai/features/drop-in-replacement)** - Replace OpenAI/Anthropic/GenAI SDKs with zero code changes +- **[Automatic Fallbacks](https://docs.getbifrost.ai/features/fallbacks)** - Seamless failover between providers and models +- **[Streaming Support](https://docs.getbifrost.ai/quickstart/gateway/streaming)** - Real-time response streaming for all providers -2. Run your binary: +### Advanced Features -- If it's in your PATH: -```bash -http -config config.json -env .env -port 8080 -pool-size 300 -``` +- **[Model Context Protocol (MCP)](https://docs.getbifrost.ai/features/mcp)** - Enable AI models to use external tools (filesystem, web search, databases) +- **[Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching)** - Intelligent response caching based on semantic similarity +- **[Load Balancing](https://docs.getbifrost.ai/features/fallbacks)** - Distribute requests across multiple API keys and providers +- **[Governance & Budget Management](https://docs.getbifrost.ai/features/governance)** - Usage tracking, rate limiting, and cost control +- **[Custom Plugins](https://docs.getbifrost.ai/enterprise/custom-plugins)** - Extensible middleware for analytics, monitoring, and custom logic -- Otherwise: -```bash -./http -config config.json -env .env -port 8080 -pool-size 300 -``` +### Enterprise Features -You can also add a flag for `-drop-excess-requests=false` in your command to drop excess requests when the buffer is full. Read more about `DROP_EXCESS_REQUESTS` and `POOL_SIZE` [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). +- **[Clustering](https://docs.getbifrost.ai/enterprise/clustering)** - Multi-node deployment with shared state +- **[SSO Integration](https://docs.getbifrost.ai/features/sso-with-google-github)** - Google, GitHub authentication +- **[Vault Support](https://docs.getbifrost.ai/enterprise/vault-support)** - Secure API key management +- **[Custom Analytics](https://docs.getbifrost.ai/features/observability)** - Detailed usage insights and monitoring +- **[In-VPC Deployments](https://docs.getbifrost.ai/enterprise/invpc-deployments)** - Private cloud deployment options -## 🧰 Usage +**Learn More**: [Complete Feature Documentation](https://docs.getbifrost.ai/features/unified-interface) -Ensure that: -- Bifrost's HTTP server is running -- The providers/models you use are configured in your JSON config file +--- -### Text Completions +## SDK Integrations -```bash -curl -X POST http://localhost:8080/v1/text/completions \ - -H "Content-Type: application/json" \ - -d '{ - "provider": "openai", - "model": "gpt-4o-mini", - "text": "Once upon a time in the land of AI,", - "params": { - "temperature": 0.7, - "max_tokens": 100 - } - }' +Replace your existing SDK base URLs to unlock Bifrost's features instantly: + +### OpenAI SDK + +```python +import openai +client = openai.OpenAI( + base_url="http://localhost:8080/openai", + api_key="dummy" # Handled by Bifrost +) ``` -### Chat Completions +### Anthropic SDK -```bash -curl -X POST http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "provider": "openai", - "model": "gpt-4o-mini", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me about Bifrost in Norse mythology."} - ], - "params": { - "temperature": 0.8, - "max_tokens": 500 - } - }' +```python +import anthropic +client = anthropic.Anthropic( + base_url="http://localhost:8080/anthropic", + api_key="dummy" # Handled by Bifrost +) +``` + +### Google GenAI SDK + +```python +import google.generativeai as genai +genai.configure( + transport="rest", + api_endpoint="http://localhost:8080/genai", + api_key="dummy" # Handled by Bifrost +) ``` +**Complete Integration Guides**: [SDK Integrations](https://docs.getbifrost.ai/integrations/what-is-an-integration) + --- -## πŸ”§ Advanced Features +## Documentation -### Fallbacks +### Getting Started -Configure fallback options in your requests: +- [Quick Setup Guide](https://docs.getbifrost.ai/quickstart/gateway/setting-up) - Detailed installation and configuration +- [Provider Configuration](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration) - Connect multiple AI providers +- [Integration Guide](https://docs.getbifrost.ai/quickstart/gateway/integrations) - SDK replacements -```json -{ - "provider": "openai", - "model": "gpt-4", - "messages": [...], - "fallbacks": [ - { - "provider": "anthropic", - "model": "claude-3-opus-20240229" - }, - { - "provider": "bedrock", - "model": "anthropic.claude-3-sonnet-20240229-v1:0" - } - ] -} -``` +### Advanced Topics + +- [MCP Tool Calling](https://docs.getbifrost.ai/features/mcp) - External tool integration +- [Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching) - Intelligent response caching +- [Fallbacks & Load Balancing](https://docs.getbifrost.ai/features/fallbacks) - Reliability and scaling +- [Budget Management](https://docs.getbifrost.ai/features/governance) - Cost control and governance -Read more about fallbacks and other additional configurations [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). +**Browse All Documentation**: [https://docs.getbifrost.ai](https://docs.getbifrost.ai) --- -Built with ❀️ by [Maxim](https://github.com/maximhq) \ No newline at end of file +*Built with ❀️ by [Maxim](https://getmaxim.ai)* diff --git a/transports/bifrost-http/.air.toml b/transports/bifrost-http/.air.toml new file mode 100644 index 000000000..d18ee38e0 --- /dev/null +++ b/transports/bifrost-http/.air.toml @@ -0,0 +1,63 @@ +root = "../.." +testdata_dir = "testdata" +tmp_dir = "transports/bifrost-http/tmp" + +[build] +args_bin = [] +bin = "tmp/main" +cmd = "go build -o ./tmp/main ." +delay = 1000 +exclude_dir = [ + "assets", + "tmp", + "vendor", + "testdata", + "ui", + "node_modules", + "transports/bifrost-http/ui", + "core/tests", + "tests", + "docs", + "npx", +] +exclude_file = [] +exclude_regex = ["_test.go"] +exclude_unchanged = false +follow_symlink = false +full_bin = "" +watch_dirs = ["."] +include_dir = [] +include_ext = ["go", "tpl", "tmpl", "html"] +include_file = [] +kill_delay = "1s" +log = "tmp/build-errors.log" +poll = false +stop_on_error = true +poll_interval = 0 +rerun = false +rerun_delay = 500 +send_interrupt = true +stop_on_root = false + +[color] +app = "" +build = "yellow" +main = "magenta" +runner = "green" +watcher = "cyan" + +[log] +main_only = false +time = false + +[misc] +clean_on_exit = false + +[proxy] +enabled = false +proxy_port = 8090 +app_port = 8080 + +[screen] +clear_on_rebuild = false +keep_scroll = true diff --git a/transports/bifrost-http/handlers/cache.go b/transports/bifrost-http/handlers/cache.go new file mode 100644 index 000000000..931c6b068 --- /dev/null +++ b/transports/bifrost-http/handlers/cache.go @@ -0,0 +1,62 @@ +package handlers + +import ( + "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/semanticcache" + "github.com/valyala/fasthttp" +) + +type CacheHandler struct { + logger schemas.Logger + plugin *semanticcache.Plugin +} + +func NewCacheHandler(plugin schemas.Plugin, logger schemas.Logger) *CacheHandler { + semanticCachePlugin, ok := plugin.(*semanticcache.Plugin) + if !ok { + logger.Fatal("Cache handler requires a semantic cache plugin") + } + + return &CacheHandler{ + plugin: semanticCachePlugin, + logger: logger, + } +} + +func (h *CacheHandler) RegisterRoutes(r *router.Router) { + r.DELETE("/api/cache/clear/{requestId}", h.clearCache) + r.DELETE("/api/cache/clear-by-key/{cacheKey}", h.clearCacheByKey) +} + +func (h *CacheHandler) clearCache(ctx *fasthttp.RequestCtx) { + requestID, ok := ctx.UserValue("requestId").(string) + if !ok { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid request ID", h.logger) + return + } + if err := h.plugin.ClearCacheForRequestID(requestID); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache", h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "message": "Cache cleared successfully", + }, h.logger) +} + +func (h *CacheHandler) clearCacheByKey(ctx *fasthttp.RequestCtx) { + cacheKey, ok := ctx.UserValue("cacheKey").(string) + if !ok { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid cache key", h.logger) + return + } + if err := h.plugin.ClearCacheForKey(cacheKey); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache", h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "message": "Cache cleared successfully", + }, h.logger) +} diff --git a/transports/bifrost-http/handlers/completions.go b/transports/bifrost-http/handlers/completions.go new file mode 100644 index 000000000..7333c9074 --- /dev/null +++ b/transports/bifrost-http/handlers/completions.go @@ -0,0 +1,650 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains completion request handlers for text and chat completions. +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "path/filepath" + "strconv" + "strings" + + "github.com/bytedance/sonic" + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// CompletionHandler manages HTTP requests for completion operations +type CompletionHandler struct { + client *bifrost.Bifrost + handlerStore lib.HandlerStore + logger schemas.Logger +} + +// NewCompletionHandler creates a new completion handler instance +func NewCompletionHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *CompletionHandler { + return &CompletionHandler{ + client: client, + handlerStore: handlerStore, + logger: logger, + } +} + +// Known fields for CompletionRequest +var completionRequestKnownFields = map[string]bool{ + "model": true, + "messages": true, + "text": true, + "fallbacks": true, + "stream": true, + "input": true, + "voice": true, + "instructions": true, + "response_format": true, + "stream_format": true, + "tool_choice": true, + "tools": true, + "temperature": true, + "top_p": true, + "top_k": true, + "max_tokens": true, + "stop_sequences": true, + "presence_penalty": true, + "frequency_penalty": true, + "parallel_tool_calls": true, + "encoding_format": true, + "dimensions": true, + "user": true, +} + +// CompletionRequest represents a request for either text or chat completion +type CompletionRequest struct { + Model string `json:"model"` // Model to use in "provider/model" format + Messages []schemas.BifrostMessage `json:"messages"` // Chat messages (for chat completion) + Text string `json:"text"` // Text input (for text completion) + Fallbacks []string `json:"fallbacks"` // Fallback providers and models in "provider/model" format + Stream *bool `json:"stream"` // Whether to stream the response + + // Speech inputs + Input schemas.EmbeddingInput `json:"input"` // string can be used for voice input as well + Voice schemas.SpeechVoiceInput `json:"voice"` + Instructions string `json:"instructions"` + ResponseFormat string `json:"response_format"` + StreamFormat *string `json:"stream_format,omitempty"` + + ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool + Tools *[]schemas.Tool `json:"tools,omitempty"` // Tools to use + Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate + StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls + EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") + Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output + User *string `json:"user,omitempty"` // User identifier for tracking + + // Dynamic parameters that can be provider-specific, they are directly + // added to the request as is. + ExtraParams map[string]interface{} `json:"-"` +} + +func (cr *CompletionRequest) UnmarshalJSON(data []byte) error { + // Use type alias to avoid infinite recursion + type Alias CompletionRequest + aux := (*Alias)(cr) + + // First unmarshal known fields + if err := sonic.Unmarshal(data, aux); err != nil { + return err + } + + // Then unmarshal to map for unknown fields + var rawData map[string]json.RawMessage + if err := sonic.Unmarshal(data, &rawData); err != nil { + return err + } + + // Initialize ExtraParams + if cr.ExtraParams == nil { + cr.ExtraParams = make(map[string]interface{}) + } + + // Extract unknown fields + for key, value := range rawData { + if !completionRequestKnownFields[key] { + var v interface{} + if err := sonic.Unmarshal(value, &v); err != nil { + continue // Skip fields that can't be unmarshaled + } + cr.ExtraParams[key] = v + } + } + + return nil +} + +func (cr *CompletionRequest) GetModelParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + ToolChoice: cr.ToolChoice, + Tools: cr.Tools, + Temperature: cr.Temperature, + TopP: cr.TopP, + TopK: cr.TopK, + MaxTokens: cr.MaxTokens, + StopSequences: cr.StopSequences, + PresencePenalty: cr.PresencePenalty, + FrequencyPenalty: cr.FrequencyPenalty, + ParallelToolCalls: cr.ParallelToolCalls, + EncodingFormat: cr.EncodingFormat, + Dimensions: cr.Dimensions, + User: cr.User, + } + + if cr.ExtraParams != nil { + for k, v := range cr.ExtraParams { + params.ExtraParams[k] = v + } + } + + return params +} + +type CompletionType string + +const ( + CompletionTypeText CompletionType = "text" + CompletionTypeChat CompletionType = "chat" + CompletionTypeEmbeddings CompletionType = "embeddings" + CompletionTypeSpeech CompletionType = "speech" + CompletionTypeTranscription CompletionType = "transcription" +) + +const ( + // Maximum file size (25MB) + MaxFileSize = 25 * 1024 * 1024 + + // Primary MIME types for audio formats + AudioMimeMP3 = "audio/mpeg" // Covers MP3, MPEG, MPGA + AudioMimeMP4 = "audio/mp4" // MP4 audio + AudioMimeM4A = "audio/x-m4a" // M4A specific + AudioMimeOGG = "audio/ogg" // OGG audio + AudioMimeWAV = "audio/wav" // WAV audio + AudioMimeWEBM = "audio/webm" // WEBM audio + AudioMimeFLAC = "audio/flac" // FLAC audio + AudioMimeFLAC2 = "audio/x-flac" // Alternative FLAC +) + +// validateAudioFile checks if the file size and format are valid +func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) error { + // Check file size + if fileHeader.Size > MaxFileSize { + return fmt.Errorf("file size exceeds maximum limit of %d MB", MaxFileSize/1024/1024) + } + + // Get file extension + ext := strings.ToLower(filepath.Ext(fileHeader.Filename)) + + // Check file extension + validExtensions := map[string]bool{ + ".flac": true, + ".mp3": true, + ".mp4": true, + ".mpeg": true, + ".mpga": true, + ".m4a": true, + ".ogg": true, + ".wav": true, + ".webm": true, + } + + if !validExtensions[ext] { + return fmt.Errorf("unsupported file format: %s. Supported formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", ext) + } + + // Open file to check MIME type + file, err := fileHeader.Open() + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + // Read first 512 bytes for MIME type detection + buffer := make([]byte, 512) + _, err = file.Read(buffer) + if err != nil && err != io.EOF { + return fmt.Errorf("failed to read file header: %v", err) + } + + // Check MIME type + mimeType := http.DetectContentType(buffer) + validMimeTypes := map[string]bool{ + // Primary MIME types + AudioMimeMP3: true, // Covers MP3, MPEG, MPGA + AudioMimeMP4: true, + AudioMimeM4A: true, + AudioMimeOGG: true, + AudioMimeWAV: true, + AudioMimeWEBM: true, + AudioMimeFLAC: true, + AudioMimeFLAC2: true, + + // Alternative MIME types + "audio/mpeg3": true, + "audio/x-wav": true, + "audio/vnd.wave": true, + "audio/x-mpeg": true, + "audio/x-mpeg3": true, + "audio/x-mpg": true, + "audio/x-mpegaudio": true, + } + + if !validMimeTypes[mimeType] { + return fmt.Errorf("invalid file type: %s. Supported audio formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", mimeType) + } + + // Reset file pointer for subsequent reads + _, err = file.Seek(0, 0) + if err != nil { + return fmt.Errorf("failed to reset file pointer: %v", err) + } + + return nil +} + +// RegisterRoutes registers all completion-related routes +func (h *CompletionHandler) RegisterRoutes(r *router.Router) { + // Completion endpoints + r.POST("/v1/text/completions", h.textCompletion) + r.POST("/v1/chat/completions", h.chatCompletion) + r.POST("/v1/embeddings", h.embeddings) + r.POST("/v1/audio/speech", h.speechCompletion) + r.POST("/v1/audio/transcriptions", h.transcriptionCompletion) +} + +// textCompletion handles POST /v1/text/completions - Process text completion requests +func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { + h.handleRequest(ctx, CompletionTypeText) +} + +// chatCompletion handles POST /v1/chat/completions - Process chat completion requests +func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { + h.handleRequest(ctx, CompletionTypeChat) +} + +// embeddings handles POST /v1/embeddings - Process embeddings requests +func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { + h.handleRequest(ctx, CompletionTypeEmbeddings) +} + +// speechCompletion handles POST /v1/audio/speech - Process speech completion requests +func (h *CompletionHandler) speechCompletion(ctx *fasthttp.RequestCtx) { + h.handleRequest(ctx, CompletionTypeSpeech) +} + +// transcriptionCompletion handles POST /v1/audio/transcriptions - Process transcription requests +func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { + // Parse multipart form + form, err := ctx.MultipartForm() + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err), h.logger) + return + } + + // Extract model (required) + modelValues := form.Value["model"] + if len(modelValues) == 0 || modelValues[0] == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Model is required", h.logger) + return + } + + provider, modelName, err := ParseModel(modelValues[0]) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Model must be in the format of 'provider/model': %v", err), h.logger) + return + } + + // Extract file (required) + fileHeaders := form.File["file"] + if len(fileHeaders) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "File is required", h.logger) + return + } + + fileHeader := fileHeaders[0] + + // // Validate file size and format + // if err := h.validateAudioFile(fileHeader); err != nil { + // SendError(ctx, fasthttp.StatusBadRequest, err.Error(), h.logger) + // return + // } + + file, err := fileHeader.Open() + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to open uploaded file: %v", err), h.logger) + return + } + defer file.Close() + + // Read file data + fileData := make([]byte, fileHeader.Size) + if _, err := file.Read(fileData); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to read uploaded file: %v", err), h.logger) + return + } + + // Create transcription input + transcriptionInput := &schemas.TranscriptionInput{ + File: fileData, + } + + // Extract optional parameters + if languageValues := form.Value["language"]; len(languageValues) > 0 && languageValues[0] != "" { + transcriptionInput.Language = &languageValues[0] + } + + if promptValues := form.Value["prompt"]; len(promptValues) > 0 && promptValues[0] != "" { + transcriptionInput.Prompt = &promptValues[0] + } + + if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" { + transcriptionInput.ResponseFormat = &responseFormatValues[0] + } + + // Create BifrostRequest + bifrostReq := &schemas.BifrostRequest{ + Model: modelName, + Provider: schemas.ModelProvider(provider), + Input: schemas.RequestInput{ + TranscriptionInput: transcriptionInput, + }, + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] != "" { + stream := streamValues[0] + if stream == "true" { + h.handleStreamingTranscriptionRequest(ctx, bifrostReq, bifrostCtx) + return + } + } + + // Make transcription request + resp, bifrostErr := h.client.TranscriptionRequest(*bifrostCtx, bifrostReq) + + // Handle response + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) +} + +// handleCompletion processes both text and chat completion requests +// It handles request parsing, validation, and response formatting +func (h *CompletionHandler) handleRequest(ctx *fasthttp.RequestCtx, completionType CompletionType) { + var req CompletionRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + if req.Model == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Model is required", h.logger) + return + } + + provider, modelName, err := ParseModel(req.Model) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Model must be in the format of 'provider/model': %v", err), h.logger) + return + } + + fallbacks := make([]schemas.Fallback, len(req.Fallbacks)) + for i, fallback := range req.Fallbacks { + fallbackProvider, fallbackModelName, err := ParseModel(fallback) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Fallback must be in the format of 'provider/model': %v", err), h.logger) + return + } + if fallbackProvider == "" || fallbackModelName == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Fallback must be in the format of 'provider/model'", h.logger) + return + } + fallbacks[i] = schemas.Fallback{ + Provider: schemas.ModelProvider(fallbackProvider), + Model: fallbackModelName, + } + } + + // Create BifrostRequest + bifrostReq := &schemas.BifrostRequest{ + Model: modelName, + Provider: schemas.ModelProvider(provider), + Params: req.GetModelParameters(), + Fallbacks: fallbacks, + } + + // Validate and set input based on completion type + switch completionType { + case CompletionTypeText: + if req.Text == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Text is required for text completion", h.logger) + return + } + bifrostReq.Input = schemas.RequestInput{ + TextCompletionInput: &req.Text, + } + case CompletionTypeChat: + if len(req.Messages) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Messages array is required for chat completion", h.logger) + return + } + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &req.Messages, + } + case CompletionTypeEmbeddings: + bifrostReq.Input = schemas.RequestInput{ + EmbeddingInput: &req.Input, + } + case CompletionTypeSpeech: + if req.Input.Text == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Input is required for speech completion", h.logger) + return + } + if req.Voice.Voice == nil && len(req.Voice.MultiVoiceConfig) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Voice is required for speech completion", h.logger) + return + } + bifrostReq.Input = schemas.RequestInput{ + SpeechInput: &schemas.SpeechInput{ + Input: *req.Input.Text, + VoiceConfig: req.Voice, + Instructions: req.Instructions, + ResponseFormat: req.ResponseFormat, + }, + } + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + // Check if streaming is requested + isStreaming := req.Stream != nil && *req.Stream || req.StreamFormat != nil && *req.StreamFormat == "sse" + + // Handle streaming for chat completions only + if isStreaming { + switch completionType { + case CompletionTypeChat: + h.handleStreamingChatCompletion(ctx, bifrostReq, bifrostCtx) + return + case CompletionTypeSpeech: + h.handleStreamingSpeech(ctx, bifrostReq, bifrostCtx) + return + } + } + + // Handle non-streaming requests + var resp *schemas.BifrostResponse + var bifrostErr *schemas.BifrostError + + switch completionType { + case CompletionTypeText: + resp, bifrostErr = h.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + case CompletionTypeChat: + resp, bifrostErr = h.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + case CompletionTypeEmbeddings: + resp, bifrostErr = h.client.EmbeddingRequest(*bifrostCtx, bifrostReq) + case CompletionTypeSpeech: + resp, bifrostErr = h.client.SpeechRequest(*bifrostCtx, bifrostReq) + } + + // Handle response + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + if completionType == CompletionTypeSpeech { + if resp.Speech.Audio == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Speech response is missing audio data", h.logger) + return + } + + ctx.Response.Header.Set("Content-Type", "audio/mpeg") + ctx.Response.Header.Set("Content-Disposition", "attachment; filename=speech.mp3") + ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Speech.Audio))) + ctx.Response.SetBody(resp.Speech.Audio) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) +} + +// handleStreamingResponse is a generic function to handle streaming responses using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, getStream func() (chan *schemas.BifrostStream, *schemas.BifrostError), extractResponse func(*schemas.BifrostStream) (interface{}, bool)) { + // Set SSE headers + ctx.SetContentType("text/event-stream") + ctx.Response.Header.Set("Cache-Control", "no-cache") + ctx.Response.Header.Set("Connection", "keep-alive") + ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") + + // Get the streaming channel + stream, bifrostErr := getStream() + if bifrostErr != nil { + // Send error in SSE format + SendSSEError(ctx, bifrostErr, h.logger) + return + } + + // Use streaming response writer + ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { + defer w.Flush() + + // Process streaming responses + for response := range stream { + if response == nil { + continue + } + + // Extract and validate the response data + data, valid := extractResponse(response) + if !valid { + continue + } + + // Convert response to JSON + responseJSON, err := sonic.Marshal(data) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to marshal streaming response: %v", err)) + continue + } + + // Send as SSE data + if _, err := fmt.Fprintf(w, "data: %s\n\n", responseJSON); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to write SSE data: %v", err)) + break + } + + // Flush immediately to send the chunk + if err := w.Flush(); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to flush SSE data: %v", err)) + break + } + } + + // Send the [DONE] marker to indicate the end of the stream + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to write SSE done marker: %v", err)) + } + }) +} + +// handleStreamingChatCompletion handles streaming chat completion requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.ChatCompletionStreamRequest(*bifrostCtx, req) + } + + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + return response, true + } + + h.handleStreamingResponse(ctx, getStream, extractResponse) +} + +// handleStreamingSpeech handles streaming speech requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.SpeechStreamRequest(*bifrostCtx, req) + } + + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + if response.Speech == nil || response.Speech.BifrostSpeechStreamResponse == nil { + return nil, false + } + return response.Speech, true + } + + h.handleStreamingResponse(ctx, getStream, extractResponse) +} + +// handleStreamingTranscriptionRequest handles streaming transcription requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.TranscriptionStreamRequest(*bifrostCtx, req) + } + + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + if response.Transcribe == nil || response.Transcribe.BifrostTranscribeStreamResponse == nil { + return nil, false + } + return response.Transcribe, true + } + + h.handleStreamingResponse(ctx, getStream, extractResponse) +} diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go new file mode 100644 index 000000000..6773a1391 --- /dev/null +++ b/transports/bifrost-http/handlers/config.go @@ -0,0 +1,131 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "slices" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ConfigHandler manages runtime configuration updates for Bifrost. +// It provides endpoints to update and retrieve settings persisted via the ConfigStore backed by sql database. +type ConfigHandler struct { + client *bifrost.Bifrost + logger schemas.Logger + store *lib.Config +} + +// NewConfigHandler creates a new handler for configuration management. +// It requires the Bifrost client, a logger, and the config store. +func NewConfigHandler(client *bifrost.Bifrost, logger schemas.Logger, store *lib.Config) *ConfigHandler { + return &ConfigHandler{ + client: client, + logger: logger, + store: store, + } +} + +// RegisterRoutes registers the configuration-related routes. +// It adds the `PUT /api/config` endpoint. +func (h *ConfigHandler) RegisterRoutes(r *router.Router) { + r.GET("/api/config", h.getConfig) + r.PUT("/api/config", h.updateConfig) + r.GET("/api/version", h.getVersion) +} + +// getVersion handles GET /api/version - Get the current version +func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) { + SendJSON(ctx, version, h.logger) +} + +// getConfig handles GET /config - Get the current configuration +func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) { + + var mapConfig = make(map[string]any) + + if query := string(ctx.QueryArgs().Peek("from_db")); query == "true" { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available", h.logger) + return + } + cc, err := h.store.ConfigStore.GetClientConfig() + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, + fmt.Sprintf("failed to fetch config from db: %v", err), h.logger) + return + } + if cc != nil { + mapConfig["client_config"] = *cc + } + } else { + mapConfig["client_config"] = h.store.ClientConfig + } + + mapConfig["is_db_connected"] = h.store.ConfigStore != nil + mapConfig["is_cache_connected"] = h.store.VectorStore != nil + mapConfig["is_logs_connected"] = h.store.LogsStore != nil + + SendJSON(ctx, mapConfig, h.logger) +} + +// updateConfig updates the core configuration settings. +// Currently, it supports hot-reloading of the `drop_excess_requests` setting. +// Note that settings like `prometheus_labels` cannot be changed at runtime. +func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Config store not initialized", h.logger) + return + } + + var req configstore.ClientConfig + + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + // Get current config with proper locking + currentConfig := h.store.ClientConfig + updatedConfig := currentConfig + + if req.DropExcessRequests != currentConfig.DropExcessRequests { + h.client.UpdateDropExcessRequests(req.DropExcessRequests) + updatedConfig.DropExcessRequests = req.DropExcessRequests + } + + if !slices.Equal(req.PrometheusLabels, currentConfig.PrometheusLabels) { + updatedConfig.PrometheusLabels = req.PrometheusLabels + } + + if !slices.Equal(req.AllowedOrigins, currentConfig.AllowedOrigins) { + updatedConfig.AllowedOrigins = req.AllowedOrigins + } + + updatedConfig.InitialPoolSize = req.InitialPoolSize + updatedConfig.EnableLogging = req.EnableLogging + updatedConfig.EnableGovernance = req.EnableGovernance + updatedConfig.EnforceGovernanceHeader = req.EnforceGovernanceHeader + updatedConfig.AllowDirectKeys = req.AllowDirectKeys + updatedConfig.MaxRequestBodySizeMB = req.MaxRequestBodySizeMB + + // Update the store with the new config + h.store.ClientConfig = updatedConfig + + if err := h.store.ConfigStore.UpdateClientConfig(&updatedConfig); err != nil { + h.logger.Warn(fmt.Sprintf("failed to save configuration: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save configuration: %v", err), h.logger) + return + } + + ctx.SetStatusCode(fasthttp.StatusOK) + SendJSON(ctx, map[string]any{ + "status": "success", + "message": "configuration updated successfully", + }, h.logger) +} diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go new file mode 100644 index 000000000..cc98e1222 --- /dev/null +++ b/transports/bifrost-http/handlers/governance.go @@ -0,0 +1,1045 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains all governance management functionality including CRUD operations for VKs, Rules, and configs. +package handlers + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/fasthttp/router" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/valyala/fasthttp" + "gorm.io/gorm" +) + +// GovernanceHandler manages HTTP requests for governance operations +type GovernanceHandler struct { + plugin *governance.GovernancePlugin + pluginStore *governance.GovernanceStore + configStore configstore.ConfigStore + logger schemas.Logger +} + +// NewGovernanceHandler creates a new governance handler instance +func NewGovernanceHandler(plugin *governance.GovernancePlugin, configStore configstore.ConfigStore, logger schemas.Logger) (*GovernanceHandler, error) { + if configStore == nil { + return nil, fmt.Errorf("config store is required") + } + + return &GovernanceHandler{ + plugin: plugin, + pluginStore: plugin.GetGovernanceStore(), + configStore: configStore, + logger: logger, + }, nil +} + +// CreateVirtualKeyRequest represents the request body for creating a virtual key +type CreateVirtualKeyRequest struct { + Name string `json:"name" validate:"required"` + Description string `json:"description,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed + AllowedProviders []string `json:"allowed_providers,omitempty"` // Empty means all providers allowed + TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID + CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID + Budget *CreateBudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey + IsActive *bool `json:"is_active,omitempty"` +} + +// UpdateVirtualKeyRequest represents the request body for updating a virtual key +type UpdateVirtualKeyRequest struct { + Description *string `json:"description,omitempty"` + AllowedModels *[]string `json:"allowed_models,omitempty"` + AllowedProviders *[]string `json:"allowed_providers,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` + RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` + KeyIDs *[]string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey + IsActive *bool `json:"is_active,omitempty"` +} + +// CreateBudgetRequest represents the request body for creating a budget +type CreateBudgetRequest struct { + MaxLimit float64 `json:"max_limit" validate:"required"` // Maximum budget in dollars + ResetDuration string `json:"reset_duration" validate:"required"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M" +} + +// UpdateBudgetRequest represents the request body for updating a budget +type UpdateBudgetRequest struct { + MaxLimit *float64 `json:"max_limit,omitempty"` + ResetDuration *string `json:"reset_duration,omitempty"` +} + +// CreateRateLimitRequest represents the request body for creating a rate limit using flexible approach +type CreateRateLimitRequest struct { + TokenMaxLimit *int64 `json:"token_max_limit,omitempty"` // Maximum tokens allowed + TokenResetDuration *string `json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M" + RequestMaxLimit *int64 `json:"request_max_limit,omitempty"` // Maximum requests allowed + RequestResetDuration *string `json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M" +} + +// UpdateRateLimitRequest represents the request body for updating a rate limit using flexible approach +type UpdateRateLimitRequest struct { + TokenMaxLimit *int64 `json:"token_max_limit,omitempty"` // Maximum tokens allowed + TokenResetDuration *string `json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M" + RequestMaxLimit *int64 `json:"request_max_limit,omitempty"` // Maximum requests allowed + RequestResetDuration *string `json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M" +} + +// CreateTeamRequest represents the request body for creating a team +type CreateTeamRequest struct { + Name string `json:"name" validate:"required"` + CustomerID *string `json:"customer_id,omitempty"` // Team can belong to a customer + Budget *CreateBudgetRequest `json:"budget,omitempty"` // Team can have its own budget +} + +// UpdateTeamRequest represents the request body for updating a team +type UpdateTeamRequest struct { + Name *string `json:"name,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// CreateCustomerRequest represents the request body for creating a customer +type CreateCustomerRequest struct { + Name string `json:"name" validate:"required"` + Budget *CreateBudgetRequest `json:"budget,omitempty"` +} + +// UpdateCustomerRequest represents the request body for updating a customer +type UpdateCustomerRequest struct { + Name *string `json:"name,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// RegisterRoutes registers all governance-related routes for the new hierarchical system +func (h *GovernanceHandler) RegisterRoutes(r *router.Router) { + // Virtual Key CRUD operations + r.GET("/api/governance/virtual-keys", h.getVirtualKeys) + r.POST("/api/governance/virtual-keys", h.createVirtualKey) + r.GET("/api/governance/virtual-keys/{vk_id}", h.getVirtualKey) + r.PUT("/api/governance/virtual-keys/{vk_id}", h.updateVirtualKey) + r.DELETE("/api/governance/virtual-keys/{vk_id}", h.deleteVirtualKey) + + // Team CRUD operations + r.GET("/api/governance/teams", h.getTeams) + r.POST("/api/governance/teams", h.createTeam) + r.GET("/api/governance/teams/{team_id}", h.getTeam) + r.PUT("/api/governance/teams/{team_id}", h.updateTeam) + r.DELETE("/api/governance/teams/{team_id}", h.deleteTeam) + + // Customer CRUD operations + r.GET("/api/governance/customers", h.getCustomers) + r.POST("/api/governance/customers", h.createCustomer) + r.GET("/api/governance/customers/{customer_id}", h.getCustomer) + r.PUT("/api/governance/customers/{customer_id}", h.updateCustomer) + r.DELETE("/api/governance/customers/{customer_id}", h.deleteCustomer) +} + +// Virtual Key CRUD Operations + +// getVirtualKeys handles GET /api/governance/virtual-keys - Get all virtual keys with relationships +func (h *GovernanceHandler) getVirtualKeys(ctx *fasthttp.RequestCtx) { + // Preload all relationships for complete information + virtualKeys, err := h.configStore.GetVirtualKeys() + if err != nil { + h.logger.Error("failed to retrieve virtual keys: %v", err) + SendError(ctx, 500, "Failed to retrieve virtual keys", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "virtual_keys": virtualKeys, + "count": len(virtualKeys), + }, h.logger) +} + +// createVirtualKey handles POST /api/governance/virtual-keys - Create a new virtual key +func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { + var req CreateVirtualKeyRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + // Validate required fields + if req.Name == "" { + SendError(ctx, 400, "Virtual key name is required", h.logger) + return + } + + // Validate mutually exclusive TeamID and CustomerID + if req.TeamID != nil && req.CustomerID != nil { + SendError(ctx, 400, "VirtualKey cannot be attached to both Team and Customer", h.logger) + return + } + + // Validate budget if provided + if req.Budget != nil { + if req.Budget.MaxLimit < 0 { + SendError(ctx, 400, fmt.Sprintf("Budget max_limit cannot be negative: %.2f", req.Budget.MaxLimit), h.logger) + return + } + // Validate reset duration format + if _, err := configstore.ParseDuration(req.Budget.ResetDuration); err != nil { + SendError(ctx, 400, fmt.Sprintf("Invalid reset duration format: %s", req.Budget.ResetDuration), h.logger) + return + } + } + + // Set defaults + isActive := true + if req.IsActive != nil { + isActive = *req.IsActive + } + + var vk configstore.TableVirtualKey + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Get the keys if DBKeyIDs are provided + var keys []configstore.TableKey + if len(req.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(req.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs: %w", err) + } + if len(keys) != len(req.KeyIDs) { + return fmt.Errorf("some keys not found: expected %d, found %d", len(req.KeyIDs), len(keys)) + } + } + + vk = configstore.TableVirtualKey{ + ID: uuid.NewString(), + Name: req.Name, + Value: uuid.NewString(), + Description: req.Description, + AllowedModels: req.AllowedModels, + AllowedProviders: req.AllowedProviders, + TeamID: req.TeamID, + CustomerID: req.CustomerID, + IsActive: isActive, + Keys: keys, // Set the keys for the many-to-many relationship + } + + if req.Budget != nil { + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: req.Budget.MaxLimit, + ResetDuration: req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + vk.BudgetID = &budget.ID + } + + if req.RateLimit != nil { + rateLimit := configstore.TableRateLimit{ + ID: uuid.NewString(), + TokenMaxLimit: req.RateLimit.TokenMaxLimit, + TokenResetDuration: req.RateLimit.TokenResetDuration, + RequestMaxLimit: req.RateLimit.RequestMaxLimit, + RequestResetDuration: req.RateLimit.RequestResetDuration, + TokenLastReset: time.Now(), + RequestLastReset: time.Now(), + } + if err := h.configStore.CreateRateLimit(&rateLimit, tx); err != nil { + return err + } + vk.RateLimitID = &rateLimit.ID + } + + if err := h.configStore.CreateVirtualKey(&vk, tx); err != nil { + return err + } + + return nil + }); err != nil { + SendError(ctx, 500, err.Error(), h.logger) + return + } + + // Load relationships for response + preloadedVk, err := h.configStore.GetVirtualKey(vk.ID) + if err != nil { + h.logger.Error("failed to load relationships for created VK: %v", err) + // If we can't load the full VK, use the basic one we just created + preloadedVk = &vk + } + + // Add to in-memory store + h.pluginStore.CreateVirtualKeyInMemory(preloadedVk) + + // If budget was created, add it to in-memory store + if vk.BudgetID != nil && preloadedVk.Budget != nil { + h.pluginStore.CreateBudgetInMemory(preloadedVk.Budget) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Virtual key created successfully", + "virtual_key": preloadedVk, + }, h.logger) +} + +// getVirtualKey handles GET /api/governance/virtual-keys/{vk_id} - Get a specific virtual key +func (h *GovernanceHandler) getVirtualKey(ctx *fasthttp.RequestCtx) { + vkID := ctx.UserValue("vk_id").(string) + + vk, err := h.configStore.GetVirtualKey(vkID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Virtual key not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve virtual key", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "virtual_key": vk, + }, h.logger) +} + +// updateVirtualKey handles PUT /api/governance/virtual-keys/{vk_id} - Update a virtual key +func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { + vkID := ctx.UserValue("vk_id").(string) + + var req UpdateVirtualKeyRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + // Validate mutually exclusive TeamID and CustomerID + if req.TeamID != nil && req.CustomerID != nil { + SendError(ctx, 400, "VirtualKey cannot be attached to both Team and Customer", h.logger) + return + } + + vk, err := h.configStore.GetVirtualKey(vkID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Virtual key not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve virtual key", h.logger) + return + } + + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Update fields if provided + if req.Description != nil { + vk.Description = *req.Description + } + if req.AllowedModels != nil { + vk.AllowedModels = *req.AllowedModels + } + if req.AllowedProviders != nil { + vk.AllowedProviders = *req.AllowedProviders + } + if req.TeamID != nil { + vk.TeamID = req.TeamID + vk.CustomerID = nil // Clear CustomerID if setting TeamID + } + if req.CustomerID != nil { + vk.CustomerID = req.CustomerID + vk.TeamID = nil // Clear TeamID if setting CustomerID + } + if req.IsActive != nil { + vk.IsActive = *req.IsActive + } + + // Handle budget updates + if req.Budget != nil { + if vk.BudgetID != nil { + // Update existing budget + budget := configstore.TableBudget{} + if err := tx.First(&budget, "id = ?", *vk.BudgetID).Error; err != nil { + return err + } + + if req.Budget.MaxLimit != nil { + budget.MaxLimit = *req.Budget.MaxLimit + } + if req.Budget.ResetDuration != nil { + budget.ResetDuration = *req.Budget.ResetDuration + } + + if err := h.configStore.UpdateBudget(&budget, tx); err != nil { + return err + } + vk.Budget = &budget + } else { + // Create new budget + if req.Budget.MaxLimit == nil || req.Budget.ResetDuration == nil { + return fmt.Errorf("both max_limit and reset_duration are required when creating a new budget") + } + if *req.Budget.MaxLimit < 0 { + return fmt.Errorf("budget max_limit cannot be negative: %.2f", *req.Budget.MaxLimit) + } + if _, err := configstore.ParseDuration(*req.Budget.ResetDuration); err != nil { + return fmt.Errorf("invalid reset duration format: %s", *req.Budget.ResetDuration) + } + // Storing now + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: *req.Budget.MaxLimit, + ResetDuration: *req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + vk.BudgetID = &budget.ID + vk.Budget = &budget + } + } + + // Handle rate limit updates + if req.RateLimit != nil { + if vk.RateLimitID != nil { + // Update existing rate limit + rateLimit := configstore.TableRateLimit{} + if err := tx.First(&rateLimit, "id = ?", *vk.RateLimitID).Error; err != nil { + return err + } + + if req.RateLimit.TokenMaxLimit != nil { + rateLimit.TokenMaxLimit = req.RateLimit.TokenMaxLimit + } + if req.RateLimit.TokenResetDuration != nil { + rateLimit.TokenResetDuration = req.RateLimit.TokenResetDuration + } + if req.RateLimit.RequestMaxLimit != nil { + rateLimit.RequestMaxLimit = req.RateLimit.RequestMaxLimit + } + if req.RateLimit.RequestResetDuration != nil { + rateLimit.RequestResetDuration = req.RateLimit.RequestResetDuration + } + + if err := h.configStore.UpdateRateLimit(&rateLimit, tx); err != nil { + return err + } + } else { + // Create new rate limit + rateLimit := configstore.TableRateLimit{ + ID: uuid.NewString(), + TokenMaxLimit: req.RateLimit.TokenMaxLimit, + TokenResetDuration: req.RateLimit.TokenResetDuration, + RequestMaxLimit: req.RateLimit.RequestMaxLimit, + RequestResetDuration: req.RateLimit.RequestResetDuration, + TokenLastReset: time.Now(), + RequestLastReset: time.Now(), + } + if err := h.configStore.CreateRateLimit(&rateLimit, tx); err != nil { + return err + } + vk.RateLimitID = &rateLimit.ID + } + } + + // Handle DBKey associations if provided + if req.KeyIDs != nil { + // Get the keys if DBKeyIDs are provided + var keys []configstore.TableKey + if len(*req.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(*req.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs: %w", err) + } + if len(keys) != len(*req.KeyIDs) { + return fmt.Errorf("some keys not found: expected %d, found %d", len(*req.KeyIDs), len(keys)) + } + } + + // Set the keys for the many-to-many relationship + vk.Keys = keys + } + + if err := h.configStore.UpdateVirtualKey(vk, tx); err != nil { + return err + } + + return nil + }); err != nil { + h.logger.Error("failed to update virtual key: %v", err) + SendError(ctx, 500, "Failed to update virtual key", h.logger) + return + } + + // Load relationships for response + preloadedVk, err := h.configStore.GetVirtualKey(vk.ID) + if err != nil { + h.logger.Error("failed to load relationships for updated VK: %v", err) + preloadedVk = vk + } + + // Update in-memory cache for budget and rate limit changes + if req.Budget != nil && preloadedVk.BudgetID != nil { + if err := h.pluginStore.UpdateBudgetInMemory(preloadedVk.Budget); err != nil { + h.logger.Error("failed to update budget cache: %v", err) + } + } + + // Update in-memory store + h.pluginStore.UpdateVirtualKeyInMemory(preloadedVk) + + SendJSON(ctx, map[string]interface{}{ + "message": "Virtual key updated successfully", + "virtual_key": preloadedVk, + }, h.logger) +} + +// deleteVirtualKey handles DELETE /api/governance/virtual-keys/{vk_id} - Delete a virtual key +func (h *GovernanceHandler) deleteVirtualKey(ctx *fasthttp.RequestCtx) { + vkID := ctx.UserValue("vk_id").(string) + + // Fetch the virtual key from the database to get the budget and rate limit + vk, err := h.configStore.GetVirtualKey(vkID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Virtual key not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve virtual key", h.logger) + return + } + + budgetID := vk.BudgetID + + if err := h.configStore.DeleteVirtualKey(vkID); err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Virtual key not found", h.logger) + return + } + SendError(ctx, 500, "Failed to delete virtual key", h.logger) + return + } + + // Remove from in-memory store + h.pluginStore.DeleteVirtualKeyInMemory(vkID) + + // Remove Budget from in-memory store + if budgetID != nil { + h.pluginStore.DeleteBudgetInMemory(*budgetID) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Virtual key deleted successfully", + }, h.logger) +} + +// Team CRUD Operations + +// getTeams handles GET /api/governance/teams - Get all teams +func (h *GovernanceHandler) getTeams(ctx *fasthttp.RequestCtx) { + customerID := string(ctx.QueryArgs().Peek("customer_id")) + + // Preload relationships for complete information + teams, err := h.configStore.GetTeams(customerID) + if err != nil { + h.logger.Error("failed to retrieve teams: %v", err) + SendError(ctx, 500, fmt.Sprintf("Failed to retrieve teams: %v", err), h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "teams": teams, + "count": len(teams), + }, h.logger) +} + +// createTeam handles POST /api/governance/teams - Create a new team +func (h *GovernanceHandler) createTeam(ctx *fasthttp.RequestCtx) { + var req CreateTeamRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + // Validate required fields + if req.Name == "" { + SendError(ctx, 400, "Team name is required", h.logger) + return + } + + // Validate budget if provided + if req.Budget != nil { + if req.Budget.MaxLimit < 0 { + SendError(ctx, 400, fmt.Sprintf("Budget max_limit cannot be negative: %.2f", req.Budget.MaxLimit), h.logger) + return + } + // Validate reset duration format + if _, err := configstore.ParseDuration(req.Budget.ResetDuration); err != nil { + SendError(ctx, 400, fmt.Sprintf("Invalid reset duration format: %s", req.Budget.ResetDuration), h.logger) + return + } + } + + var team configstore.TableTeam + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + team = configstore.TableTeam{ + ID: uuid.NewString(), + Name: req.Name, + CustomerID: req.CustomerID, + } + + if req.Budget != nil { + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: req.Budget.MaxLimit, + ResetDuration: req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + team.BudgetID = &budget.ID + } + + if err := h.configStore.CreateTeam(&team, tx); err != nil { + return err + } + return nil + }); err != nil { + h.logger.Error("failed to create team: %v", err) + SendError(ctx, 500, "failed to create team", h.logger) + return + } + + // Load relationships for response + preloadedTeam, err := h.configStore.GetTeam(team.ID) + if err != nil { + h.logger.Error("failed to load relationships for created team: %v", err) + preloadedTeam = &team + } + + // Add to in-memory store + h.pluginStore.CreateTeamInMemory(preloadedTeam) + + // If budget was created, add it to in-memory store + if preloadedTeam.BudgetID != nil { + h.pluginStore.CreateBudgetInMemory(preloadedTeam.Budget) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Team created successfully", + "team": preloadedTeam, + }, h.logger) +} + +// getTeam handles GET /api/governance/teams/{team_id} - Get a specific team +func (h *GovernanceHandler) getTeam(ctx *fasthttp.RequestCtx) { + teamID := ctx.UserValue("team_id").(string) + + team, err := h.configStore.GetTeam(teamID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Team not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve team", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "team": team, + }, h.logger) +} + +// updateTeam handles PUT /api/governance/teams/{team_id} - Update a team +func (h *GovernanceHandler) updateTeam(ctx *fasthttp.RequestCtx) { + teamID := ctx.UserValue("team_id").(string) + + var req UpdateTeamRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + team, err := h.configStore.GetTeam(teamID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Team not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve team", h.logger) + return + } + + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Update fields if provided + if req.Name != nil { + team.Name = *req.Name + } + if req.CustomerID != nil { + team.CustomerID = req.CustomerID + } + + // Handle budget updates + if req.Budget != nil { + if team.BudgetID != nil { + // Update existing budget + budget, err := h.configStore.GetBudget(*team.BudgetID, tx) + if err != nil { + return err + } + + if req.Budget.MaxLimit != nil { + budget.MaxLimit = *req.Budget.MaxLimit + } + if req.Budget.ResetDuration != nil { + budget.ResetDuration = *req.Budget.ResetDuration + } + + if err := h.configStore.UpdateBudget(budget, tx); err != nil { + return err + } + team.Budget = budget + } else { + // Create new budget + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: *req.Budget.MaxLimit, + ResetDuration: *req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + team.BudgetID = &budget.ID + team.Budget = &budget + } + } + + if err := h.configStore.UpdateTeam(team, tx); err != nil { + return err + } + + return nil + }); err != nil { + SendError(ctx, 500, "Failed to update team", h.logger) + return + } + + // Update in-memory cache for budget changes + if req.Budget != nil && team.BudgetID != nil { + if err := h.pluginStore.UpdateBudgetInMemory(team.Budget); err != nil { + h.logger.Error("failed to update budget cache: %v", err) + } + } + + // Load relationships for response + preloadedTeam, err := h.configStore.GetTeam(team.ID) + if err != nil { + h.logger.Error("failed to load relationships for updated team: %v", err) + preloadedTeam = team + } + + // Update in-memory store + h.pluginStore.UpdateTeamInMemory(preloadedTeam) + + SendJSON(ctx, map[string]interface{}{ + "message": "Team updated successfully", + "team": preloadedTeam, + }, h.logger) +} + +// deleteTeam handles DELETE /api/governance/teams/{team_id} - Delete a team +func (h *GovernanceHandler) deleteTeam(ctx *fasthttp.RequestCtx) { + teamID := ctx.UserValue("team_id").(string) + + team, err := h.configStore.GetTeam(teamID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Team not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve team", h.logger) + return + } + + budgetID := team.BudgetID + + if err := h.configStore.DeleteTeam(teamID); err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Team not found", h.logger) + return + } + SendError(ctx, 500, "Failed to delete team", h.logger) + return + } + + // Remove from in-memory store + h.pluginStore.DeleteTeamInMemory(teamID) + + // Remove Budget from in-memory store + if budgetID != nil { + h.pluginStore.DeleteBudgetInMemory(*budgetID) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Team deleted successfully", + }, h.logger) +} + +// Customer CRUD Operations + +// getCustomers handles GET /api/governance/customers - Get all customers +func (h *GovernanceHandler) getCustomers(ctx *fasthttp.RequestCtx) { + customers, err := h.configStore.GetCustomers() + if err != nil { + h.logger.Error("failed to retrieve customers: %v", err) + SendError(ctx, 500, "failed to retrieve customers", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "customers": customers, + "count": len(customers), + }, h.logger) +} + +// createCustomer handles POST /api/governance/customers - Create a new customer +func (h *GovernanceHandler) createCustomer(ctx *fasthttp.RequestCtx) { + var req CreateCustomerRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + // Validate required fields + if req.Name == "" { + SendError(ctx, 400, "Customer name is required", h.logger) + return + } + + // Validate budget if provided + if req.Budget != nil { + if req.Budget.MaxLimit < 0 { + SendError(ctx, 400, fmt.Sprintf("Budget max_limit cannot be negative: %.2f", req.Budget.MaxLimit), h.logger) + return + } + // Validate reset duration format + if _, err := configstore.ParseDuration(req.Budget.ResetDuration); err != nil { + SendError(ctx, 400, fmt.Sprintf("Invalid reset duration format: %s", req.Budget.ResetDuration), h.logger) + return + } + } + + var customer configstore.TableCustomer + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + customer = configstore.TableCustomer{ + ID: uuid.NewString(), + Name: req.Name, + } + + if req.Budget != nil { + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: req.Budget.MaxLimit, + ResetDuration: req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + customer.BudgetID = &budget.ID + } + + if err := h.configStore.CreateCustomer(&customer, tx); err != nil { + return err + } + return nil + }); err != nil { + SendError(ctx, 500, "failed to create customer", h.logger) + return + } + + // Load relationships for response + preloadedCustomer, err := h.configStore.GetCustomer(customer.ID) + if err != nil { + h.logger.Error("failed to load relationships for created customer: %v", err) + preloadedCustomer = &customer + } + + // Add to in-memory store + h.pluginStore.CreateCustomerInMemory(preloadedCustomer) + + // If budget was created, add it to in-memory store + if preloadedCustomer.BudgetID != nil { + h.pluginStore.CreateBudgetInMemory(preloadedCustomer.Budget) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Customer created successfully", + "customer": preloadedCustomer, + }, h.logger) +} + +// getCustomer handles GET /api/governance/customers/{customer_id} - Get a specific customer +func (h *GovernanceHandler) getCustomer(ctx *fasthttp.RequestCtx) { + customerID := ctx.UserValue("customer_id").(string) + + customer, err := h.configStore.GetCustomer(customerID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Customer not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve customer", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "customer": customer, + }, h.logger) +} + +// updateCustomer handles PUT /api/governance/customers/{customer_id} - Update a customer +func (h *GovernanceHandler) updateCustomer(ctx *fasthttp.RequestCtx) { + customerID := ctx.UserValue("customer_id").(string) + + var req UpdateCustomerRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, 400, "Invalid JSON", h.logger) + return + } + + customer, err := h.configStore.GetCustomer(customerID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Customer not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve customer", h.logger) + return + } + + if err := h.configStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Update fields if provided + if req.Name != nil { + customer.Name = *req.Name + } + + // Handle budget updates + if req.Budget != nil { + if customer.BudgetID != nil { + // Update existing budget + budget, err := h.configStore.GetBudget(*customer.BudgetID, tx) + if err != nil { + return err + } + + if req.Budget.MaxLimit != nil { + budget.MaxLimit = *req.Budget.MaxLimit + } + if req.Budget.ResetDuration != nil { + budget.ResetDuration = *req.Budget.ResetDuration + } + + if err := h.configStore.UpdateBudget(budget, tx); err != nil { + return err + } + customer.Budget = budget + } else { + // Create new budget + budget := configstore.TableBudget{ + ID: uuid.NewString(), + MaxLimit: *req.Budget.MaxLimit, + ResetDuration: *req.Budget.ResetDuration, + LastReset: time.Now(), + CurrentUsage: 0, + } + if err := h.configStore.CreateBudget(&budget, tx); err != nil { + return err + } + customer.BudgetID = &budget.ID + customer.Budget = &budget + } + } + + if err := h.configStore.UpdateCustomer(customer, tx); err != nil { + return err + } + + return nil + }); err != nil { + SendError(ctx, 500, "Failed to update customer", h.logger) + return + } + + // Update in-memory cache for budget changes + if req.Budget != nil && customer.BudgetID != nil { + if err := h.pluginStore.UpdateBudgetInMemory(customer.Budget); err != nil { + h.logger.Error("failed to update budget cache: %v", err) + } + } + + // Load relationships for response + preloadedCustomer, err := h.configStore.GetCustomer(customer.ID) + if err != nil { + h.logger.Error("failed to load relationships for updated customer: %v", err) + preloadedCustomer = customer + } + + // Update in-memory store + h.pluginStore.UpdateCustomerInMemory(preloadedCustomer) + + SendJSON(ctx, map[string]interface{}{ + "message": "Customer updated successfully", + "customer": preloadedCustomer, + }, h.logger) +} + +// deleteCustomer handles DELETE /api/governance/customers/{customer_id} - Delete a customer +func (h *GovernanceHandler) deleteCustomer(ctx *fasthttp.RequestCtx) { + customerID := ctx.UserValue("customer_id").(string) + + customer, err := h.configStore.GetCustomer(customerID) + if err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Customer not found", h.logger) + return + } + SendError(ctx, 500, "Failed to retrieve customer", h.logger) + return + } + + budgetID := customer.BudgetID + + if err := h.configStore.DeleteCustomer(customerID); err != nil { + if err == gorm.ErrRecordNotFound { + SendError(ctx, 404, "Customer not found", h.logger) + return + } + SendError(ctx, 500, "Failed to delete customer", h.logger) + return + } + + // Remove from in-memory store + h.pluginStore.DeleteCustomerInMemory(customerID) + + // Remove Budget from in-memory store + if budgetID != nil { + h.pluginStore.DeleteBudgetInMemory(*budgetID) + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Customer deleted successfully", + }, h.logger) +} diff --git a/transports/bifrost-http/handlers/handlers.go b/transports/bifrost-http/handlers/handlers.go new file mode 100644 index 000000000..7e2baffc7 --- /dev/null +++ b/transports/bifrost-http/handlers/handlers.go @@ -0,0 +1,8 @@ +package handlers + +var version string + +// SetVersion sets the version of the application. +func SetVersion(v string) { + version = v +} diff --git a/transports/bifrost-http/handlers/integrations.go b/transports/bifrost-http/handlers/integrations.go new file mode 100644 index 000000000..d3b639efd --- /dev/null +++ b/transports/bifrost-http/handlers/integrations.go @@ -0,0 +1,44 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains integration management handlers for AI provider integrations. +package handlers + +import ( + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/langchain" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/litellm" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +// IntegrationHandler manages HTTP requests for AI provider integrations +type IntegrationHandler struct { + extensions []integrations.ExtensionRouter +} + +// NewIntegrationHandler creates a new integration handler instance +func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *IntegrationHandler { + // Initialize all available integration routers + extensions := []integrations.ExtensionRouter{ + openai.NewOpenAIRouter(client, handlerStore), + anthropic.NewAnthropicRouter(client, handlerStore), + genai.NewGenAIRouter(client, handlerStore), + litellm.NewLiteLLMRouter(client, handlerStore), + langchain.NewLangChainRouter(client, handlerStore), + } + + return &IntegrationHandler{ + extensions: extensions, + } +} + +// RegisterRoutes registers all integration routes for AI provider compatibility endpoints +func (h *IntegrationHandler) RegisterRoutes(r *router.Router) { + // Register routes for each integration extension + for _, extension := range h.extensions { + extension.RegisterRoutes(r) + } +} diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go new file mode 100644 index 000000000..915a35497 --- /dev/null +++ b/transports/bifrost-http/handlers/logging.go @@ -0,0 +1,183 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains logging-related handlers for log search, stats, and management. +package handlers + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/plugins/logging" + "github.com/valyala/fasthttp" +) + +// LoggingHandler manages HTTP requests for logging operations +type LoggingHandler struct { + logManager logging.LogManager + logger schemas.Logger +} + +// NewLoggingHandler creates a new logging handler instance +func NewLoggingHandler(logManager logging.LogManager, logger schemas.Logger) *LoggingHandler { + return &LoggingHandler{ + logManager: logManager, + logger: logger, + } +} + +// RegisterRoutes registers all logging-related routes +func (h *LoggingHandler) RegisterRoutes(r *router.Router) { + // Log retrieval with filtering, search, and pagination + r.GET("/api/logs", h.getLogs) + r.GET("/api/logs/dropped", h.getDroppedRequests) + r.GET("/api/logs/models", h.getAvailableModels) +} + +// getLogs handles GET /api/logs - Get logs with filtering, search, and pagination via query parameters +func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { + // Parse query parameters into filters + filters := &logstore.SearchFilters{} + pagination := &logstore.PaginationOptions{} + + // Extract filters from query parameters + if providers := string(ctx.QueryArgs().Peek("providers")); providers != "" { + filters.Providers = parseCommaSeparated(providers) + } + if models := string(ctx.QueryArgs().Peek("models")); models != "" { + filters.Models = parseCommaSeparated(models) + } + if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { + filters.Status = parseCommaSeparated(statuses) + } + if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { + filters.Objects = parseCommaSeparated(objects) + } + if startTime := string(ctx.QueryArgs().Peek("start_time")); startTime != "" { + if t, err := time.Parse(time.RFC3339, startTime); err == nil { + filters.StartTime = &t + } + } + if endTime := string(ctx.QueryArgs().Peek("end_time")); endTime != "" { + if t, err := time.Parse(time.RFC3339, endTime); err == nil { + filters.EndTime = &t + } + } + if minLatency := string(ctx.QueryArgs().Peek("min_latency")); minLatency != "" { + if f, err := strconv.ParseFloat(minLatency, 64); err == nil { + filters.MinLatency = &f + } + } + if maxLatency := string(ctx.QueryArgs().Peek("max_latency")); maxLatency != "" { + if val, err := strconv.ParseFloat(maxLatency, 64); err == nil { + filters.MaxLatency = &val + } + } + if minTokens := string(ctx.QueryArgs().Peek("min_tokens")); minTokens != "" { + if val, err := strconv.Atoi(minTokens); err == nil { + filters.MinTokens = &val + } + } + if maxTokens := string(ctx.QueryArgs().Peek("max_tokens")); maxTokens != "" { + if val, err := strconv.Atoi(maxTokens); err == nil { + filters.MaxTokens = &val + } + } + if cost := string(ctx.QueryArgs().Peek("min_cost")); cost != "" { + if val, err := strconv.ParseFloat(cost, 64); err == nil { + filters.MinCost = &val + } + } + if maxCost := string(ctx.QueryArgs().Peek("max_cost")); maxCost != "" { + if val, err := strconv.ParseFloat(maxCost, 64); err == nil { + filters.MaxCost = &val + } + } + if contentSearch := string(ctx.QueryArgs().Peek("content_search")); contentSearch != "" { + filters.ContentSearch = contentSearch + } + + // Extract pagination parameters + pagination.Limit = 50 // Default limit + if limit := string(ctx.QueryArgs().Peek("limit")); limit != "" { + if i, err := strconv.Atoi(limit); err == nil { + if i <= 0 { + SendError(ctx, fasthttp.StatusBadRequest, "limit must be greater than 0", h.logger) + return + } + if i > 1000 { + SendError(ctx, fasthttp.StatusBadRequest, "limit cannot exceed 1000", h.logger) + return + } + pagination.Limit = i + } + } + + pagination.Offset = 0 // Default offset + if offset := string(ctx.QueryArgs().Peek("offset")); offset != "" { + if i, err := strconv.Atoi(offset); err == nil { + if i < 0 { + SendError(ctx, fasthttp.StatusBadRequest, "offset cannot be negative", h.logger) + return + } + pagination.Offset = i + } + } + + // Sort parameters + pagination.SortBy = "timestamp" // Default sort field + if sortBy := string(ctx.QueryArgs().Peek("sort_by")); sortBy != "" { + if sortBy == "timestamp" || sortBy == "latency" || sortBy == "tokens" || sortBy == "cost" { + pagination.SortBy = sortBy + } + } + + pagination.Order = "desc" // Default sort order + if order := string(ctx.QueryArgs().Peek("order")); order != "" { + if order == "asc" || order == "desc" { + pagination.Order = order + } + } + + result, err := h.logManager.Search(filters, pagination) + if err != nil { + h.logger.Error("failed to search logs: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Search failed: %v", err), h.logger) + return + } + + SendJSON(ctx, result, h.logger) +} + +// getDroppedRequests handles GET /api/logs/dropped - Get the number of dropped requests +func (h *LoggingHandler) getDroppedRequests(ctx *fasthttp.RequestCtx) { + droppedRequests := h.logManager.GetDroppedRequests() + SendJSON(ctx, map[string]int64{"dropped_requests": droppedRequests}, h.logger) +} + +// getAvailableModels handles GET /api/logs/models - Get all unique models from logs +func (h *LoggingHandler) getAvailableModels(ctx *fasthttp.RequestCtx) { + models := h.logManager.GetAvailableModels() + SendJSON(ctx, map[string]interface{}{"models": models}, h.logger) +} + +// Helper functions + +// parseCommaSeparated splits a comma-separated string into a slice +func parseCommaSeparated(s string) []string { + if s == "" { + return nil + } + + var result []string + for _, item := range strings.Split(s, ",") { + if trimmed := strings.TrimSpace(item); trimmed != "" { + result = append(result, trimmed) + } + } + + return result +} diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go new file mode 100644 index 000000000..37fedf8fc --- /dev/null +++ b/transports/bifrost-http/handlers/mcp.go @@ -0,0 +1,219 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains MCP (Model Context Protocol) tool execution handlers. +package handlers + +import ( + "encoding/json" + "fmt" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// MCPHandler manages HTTP requests for MCP tool operations +type MCPHandler struct { + client *bifrost.Bifrost + logger schemas.Logger + store *lib.Config +} + +// NewMCPHandler creates a new MCP handler instance +func NewMCPHandler(client *bifrost.Bifrost, logger schemas.Logger, store *lib.Config) *MCPHandler { + return &MCPHandler{ + client: client, + logger: logger, + store: store, + } +} + +// RegisterRoutes registers all MCP-related routes +func (h *MCPHandler) RegisterRoutes(r *router.Router) { + // MCP tool execution endpoint + r.POST("/v1/mcp/tool/execute", h.executeTool) + r.GET("/api/mcp/clients", h.getMCPClients) + r.POST("/api/mcp/client", h.addMCPClient) + r.PUT("/api/mcp/client/{name}", h.editMCPClientTools) + r.DELETE("/api/mcp/client/{name}", h.removeMCPClient) + r.POST("/api/mcp/client/{name}/reconnect", h.reconnectMCPClient) +} + +// executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool +func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { + var req schemas.ToolCall + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + // Validate required fields + if req.Function.Name == nil || *req.Function.Name == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required", h.logger) + return + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, false) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + // Execute MCP tool + resp, bifrostErr := h.client.ExecuteMCPTool(*bifrostCtx, req) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) +} + +// getMCPClients handles GET /api/mcp/clients - Get all MCP clients +func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { + // Get clients from store config + configsInStore := h.store.MCPConfig + if configsInStore == nil { + SendJSON(ctx, []schemas.MCPClient{}, h.logger) + return + } + + // Get actual connected clients from Bifrost + clientsInBifrost, err := h.client.GetMCPClients() + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get MCP clients from Bifrost: %v", err), h.logger) + return + } + + // Create a map of connected clients for quick lookup + connectedClientsMap := make(map[string]schemas.MCPClient) + for _, client := range clientsInBifrost { + connectedClientsMap[client.Name] = client + } + + // Build the final client list, including errored clients + clients := make([]schemas.MCPClient, 0, len(configsInStore.ClientConfigs)) + + for _, configClient := range configsInStore.ClientConfigs { + if connectedClient, exists := connectedClientsMap[configClient.Name]; exists { + // Client is connected, use the actual client data + clients = append(clients, schemas.MCPClient{ + Name: connectedClient.Name, + Config: h.store.RedactMCPClientConfig(connectedClient.Config), + Tools: connectedClient.Tools, + State: connectedClient.State, + }) + } else { + // Client is in config but not connected, mark as errored + clients = append(clients, schemas.MCPClient{ + Name: configClient.Name, + Config: h.store.RedactMCPClientConfig(configClient), + Tools: []string{}, // No tools available since connection failed + State: schemas.MCPConnectionStateError, + }) + } + } + + SendJSON(ctx, clients, h.logger) +} + +// reconnectMCPClient handles POST /api/mcp/client/{name}/reconnect - Reconnect an MCP client +func (h *MCPHandler) reconnectMCPClient(ctx *fasthttp.RequestCtx) { + name, err := getNameFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid name: %v", err), h.logger) + return + } + + if err := h.client.ReconnectMCPClient(name); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to reconnect MCP client: %v", err), h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "status": "success", + "message": "MCP client reconnected successfully", + }, h.logger) +} + +// addMCPClient handles POST /api/mcp/client - Add a new MCP client +func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { + var req schemas.MCPClientConfig + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + if err := h.store.AddMCPClient(req); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add MCP client: %v", err), h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "status": "success", + "message": "MCP client added successfully", + }, h.logger) +} + +// editMCPClientTools handles PUT /api/mcp/client/{name} - Edit MCP client tools +func (h *MCPHandler) editMCPClientTools(ctx *fasthttp.RequestCtx) { + name, err := getNameFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid name: %v", err), h.logger) + return + } + + var req struct { + ToolsToExecute []string `json:"tools_to_execute,omitempty"` + ToolsToSkip []string `json:"tools_to_skip,omitempty"` + } + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + if err := h.store.EditMCPClientTools(name, req.ToolsToExecute, req.ToolsToSkip); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client tools: %v", err), h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "status": "success", + "message": "MCP client tools edited successfully", + }, h.logger) +} + +// removeMCPClient handles DELETE /api/mcp/client/{name} - Remove an MCP client +func (h *MCPHandler) removeMCPClient(ctx *fasthttp.RequestCtx) { + name, err := getNameFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid name: %v", err), h.logger) + return + } + + if err := h.store.RemoveMCPClient(name); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to remove MCP client: %v", err), h.logger) + return + } + + SendJSON(ctx, map[string]any{ + "status": "success", + "message": "MCP client removed successfully", + }, h.logger) +} + +func getNameFromCtx(ctx *fasthttp.RequestCtx) (string, error) { + nameValue := ctx.UserValue("name") + if nameValue == nil { + return "", fmt.Errorf("missing name parameter") + } + nameStr, ok := nameValue.(string) + if !ok { + return "", fmt.Errorf("invalid name parameter type") + } + + return nameStr, nil +} diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go new file mode 100644 index 000000000..3842ada91 --- /dev/null +++ b/transports/bifrost-http/handlers/plugins.go @@ -0,0 +1,236 @@ +package handlers + +import ( + "encoding/json" + "errors" + + "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/valyala/fasthttp" + "gorm.io/gorm" +) + +type PluginsHandler struct { + logger schemas.Logger + configStore configstore.ConfigStore +} + +func NewPluginsHandler(configStore configstore.ConfigStore, logger schemas.Logger) *PluginsHandler { + return &PluginsHandler{ + configStore: configStore, + logger: logger, + } +} + +type CreatePluginRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Config map[string]interface{} `json:"config"` +} + +type UpdatePluginRequest struct { + Enabled bool `json:"enabled"` + Config map[string]interface{} `json:"config"` +} + +func (h *PluginsHandler) RegisterRoutes(r *router.Router) { + r.GET("/api/plugins", h.getPlugins) + r.GET("/api/plugins/{name}", h.getPlugin) + r.POST("/api/plugins", h.createPlugin) + r.PUT("/api/plugins/{name}", h.updatePlugin) + r.DELETE("/api/plugins/{name}", h.deletePlugin) +} + +func (h *PluginsHandler) getPlugins(ctx *fasthttp.RequestCtx) { + plugins, err := h.configStore.GetPlugins() + if err != nil { + h.logger.Error("failed to get plugins: %v", err) + SendError(ctx, 500, "Failed to retrieve plugins", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "plugins": plugins, + "count": len(plugins), + }, h.logger) +} + +func (h *PluginsHandler) getPlugin(ctx *fasthttp.RequestCtx) { + // Safely validate the "name" parameter + nameValue := ctx.UserValue("name") + if nameValue == nil { + h.logger.Warn("missing required 'name' parameter in request") + SendError(ctx, 400, "Missing required 'name' parameter", h.logger) + return + } + + name, ok := nameValue.(string) + if !ok { + h.logger.Warn("invalid 'name' parameter type, expected string but got %T", nameValue) + SendError(ctx, 400, "Invalid 'name' parameter type, expected string", h.logger) + return + } + + if name == "" { + h.logger.Warn("empty 'name' parameter provided") + SendError(ctx, 400, "Empty 'name' parameter not allowed", h.logger) + return + } + + plugin, err := h.configStore.GetPlugin(name) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + SendError(ctx, fasthttp.StatusNotFound, "Plugin not found", h.logger) + return + } + h.logger.Error("failed to get plugin: %v", err) + SendError(ctx, 500, "Failed to retrieve plugin", h.logger) + return + } + SendJSON(ctx, plugin, h.logger) +} + +func (h *PluginsHandler) createPlugin(ctx *fasthttp.RequestCtx) { + var request CreatePluginRequest + if err := json.Unmarshal(ctx.PostBody(), &request); err != nil { + h.logger.Error("failed to unmarshal create plugin request: %v", err) + SendError(ctx, 400, "Invalid request body", h.logger) + return + } + + // Validate required fields + if request.Name == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Plugin name is required", h.logger) + return + } + + // Check if plugin already exists + existingPlugin, err := h.configStore.GetPlugin(request.Name) + if err == nil && existingPlugin != nil { + SendError(ctx, fasthttp.StatusConflict, "Plugin already exists", h.logger) + return + } + + if err := h.configStore.CreatePlugin(&configstore.TablePlugin{ + Name: request.Name, + Enabled: request.Enabled, + Config: request.Config, + }); err != nil { + h.logger.Error("failed to create plugin: %v", err) + SendError(ctx, 500, "Failed to create plugin", h.logger) + return + } + + plugin, err := h.configStore.GetPlugin(request.Name) + if err != nil { + h.logger.Error("failed to get plugin: %v", err) + SendError(ctx, 500, "Failed to retrieve plugin", h.logger) + return + } + + ctx.SetStatusCode(fasthttp.StatusCreated) + SendJSON(ctx, map[string]interface{}{ + "message": "Plugin created successfully", + "plugin": plugin, + }, h.logger) +} + +func (h *PluginsHandler) updatePlugin(ctx *fasthttp.RequestCtx) { + // Safely validate the "name" parameter + nameValue := ctx.UserValue("name") + if nameValue == nil { + h.logger.Warn("missing required 'name' parameter in update plugin request") + SendError(ctx, 400, "Missing required 'name' parameter", h.logger) + return + } + + name, ok := nameValue.(string) + if !ok { + h.logger.Warn("invalid 'name' parameter type in update plugin request, expected string but got %T", nameValue) + SendError(ctx, 400, "Invalid 'name' parameter type, expected string", h.logger) + return + } + + if name == "" { + h.logger.Warn("empty 'name' parameter provided in update plugin request") + SendError(ctx, 400, "Empty 'name' parameter not allowed", h.logger) + return + } + + // Check if plugin exists + if _, err := h.configStore.GetPlugin(name); err != nil { + h.logger.Warn("plugin not found for update: %s", name) + SendError(ctx, fasthttp.StatusNotFound, "Plugin not found", h.logger) + return + } + + var request UpdatePluginRequest + if err := json.Unmarshal(ctx.PostBody(), &request); err != nil { + h.logger.Error("failed to unmarshal update plugin request: %v", err) + SendError(ctx, 400, "Invalid request body", h.logger) + return + } + + if err := h.configStore.UpdatePlugin(&configstore.TablePlugin{ + Name: name, + Enabled: request.Enabled, + Config: request.Config, + }); err != nil { + h.logger.Error("failed to update plugin: %v", err) + SendError(ctx, 500, "Failed to update plugin", h.logger) + return + } + + plugin, err := h.configStore.GetPlugin(name) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + SendError(ctx, fasthttp.StatusNotFound, "Plugin not found", h.logger) + return + } + h.logger.Error("failed to get plugin: %v", err) + SendError(ctx, 500, "Failed to retrieve plugin", h.logger) + return + } + + SendJSON(ctx, map[string]interface{}{ + "message": "Plugin updated successfully", + "plugin": plugin, + }, h.logger) +} + +func (h *PluginsHandler) deletePlugin(ctx *fasthttp.RequestCtx) { + // Safely validate the "name" parameter + nameValue := ctx.UserValue("name") + if nameValue == nil { + h.logger.Warn("missing required 'name' parameter in delete plugin request") + SendError(ctx, 400, "Missing required 'name' parameter", h.logger) + return + } + + name, ok := nameValue.(string) + if !ok { + h.logger.Warn("invalid 'name' parameter type in delete plugin request, expected string but got %T", nameValue) + SendError(ctx, 400, "Invalid 'name' parameter type, expected string", h.logger) + return + } + + if name == "" { + h.logger.Warn("empty 'name' parameter provided in delete plugin request") + SendError(ctx, 400, "Empty 'name' parameter not allowed", h.logger) + return + } + + if err := h.configStore.DeletePlugin(name); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + SendError(ctx, fasthttp.StatusNotFound, "Plugin not found", h.logger) + return + } + h.logger.Error("failed to delete plugin: %v", err) + SendError(ctx, 500, "Failed to delete plugin", h.logger) + return + } + SendJSON(ctx, map[string]interface{}{ + "message": "Plugin deleted successfully", + }, h.logger) +} diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go new file mode 100644 index 000000000..fea72c641 --- /dev/null +++ b/transports/bifrost-http/handlers/providers.go @@ -0,0 +1,602 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains all provider management functionality including CRUD operations. +package handlers + +import ( + "encoding/json" + "fmt" + "net/url" + "slices" + "sort" + "strings" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ProviderHandler manages HTTP requests for provider operations +type ProviderHandler struct { + store *lib.Config + client *bifrost.Bifrost + logger schemas.Logger +} + +// NewProviderHandler creates a new provider handler instance +func NewProviderHandler(store *lib.Config, client *bifrost.Bifrost, logger schemas.Logger) *ProviderHandler { + return &ProviderHandler{ + store: store, + client: client, + logger: logger, + } +} + +// ProviderResponse represents the response for provider operations +type ProviderResponse struct { + Name schemas.ModelProvider `json:"name"` + Keys []schemas.Key `json:"keys"` // API keys for the provider + NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config"` // Proxy configuration + SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration +} + +// ListProvidersResponse represents the response for listing all providers +type ListProvidersResponse struct { + Providers []ProviderResponse `json:"providers"` + Total int `json:"total"` +} + +// ErrorResponse represents an error response +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message,omitempty"` +} + +// RegisterRoutes registers all provider management routes +func (h *ProviderHandler) RegisterRoutes(r *router.Router) { + // Provider CRUD operations + r.GET("/api/providers", h.listProviders) + r.GET("/api/providers/{provider}", h.getProvider) + r.POST("/api/providers", h.addProvider) + r.PUT("/api/providers/{provider}", h.updateProvider) + r.DELETE("/api/providers/{provider}", h.deleteProvider) + r.GET("/api/keys", h.listKeys) +} + +// listProviders handles GET /api/providers - List all providers +func (h *ProviderHandler) listProviders(ctx *fasthttp.RequestCtx) { + providers, err := h.store.GetAllProviders() + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get providers: %v", err), h.logger) + return + } + + providerResponses := []ProviderResponse{} + + // Sort providers alphabetically + sort.Slice(providers, func(i, j int) bool { + return string(providers[i]) < string(providers[j]) + }) + + for _, provider := range providers { + config, err := h.store.GetProviderConfigRedacted(provider) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to get config for provider %s: %v", provider, err)) + // Include provider even if config fetch fails + providerResponses = append(providerResponses, ProviderResponse{ + Name: provider, + }) + continue + } + + providerResponses = append(providerResponses, h.getProviderResponseFromConfig(provider, *config)) + } + + response := ListProvidersResponse{ + Providers: providerResponses, + Total: len(providerResponses), + } + + SendJSON(ctx, response, h.logger) +} + +// getProvider handles GET /api/providers/{provider} - Get specific provider +func (h *ProviderHandler) getProvider(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err), h.logger) + return + } + + config, err := h.store.GetProviderConfigRedacted(provider) + if err != nil { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err), h.logger) + return + } + + response := h.getProviderResponseFromConfig(provider, *config) + + SendJSON(ctx, response, h.logger) +} + +// addProvider handles POST /api/providers - Add a new provider +func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { + // Payload structure + var payload = struct { + Provider schemas.ModelProvider `json:"provider"` + Keys []schemas.Key `json:"keys"` // API keys for the provider + NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration + }{} + + if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err), h.logger) + return + } + + // Validate provider + if payload.Provider == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Missing provider", h.logger) + return + } + + if payload.CustomProviderConfig != nil { + // custom provider key should not be same as standard provider names + if bifrost.IsStandardProvider(payload.Provider) { + SendError(ctx, fasthttp.StatusBadRequest, "Custom provider cannot be same as a standard provider", h.logger) + return + } + + if payload.CustomProviderConfig.BaseProviderType == "" { + SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType is required when CustomProviderConfig is provided", h.logger) + return + } + + // check if base provider is a supported base provider + if !bifrost.IsSupportedBaseProvider(payload.CustomProviderConfig.BaseProviderType) { + SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType must be a standard provider", h.logger) + return + } + } + + if payload.ConcurrencyAndBufferSize != nil { + if payload.ConcurrencyAndBufferSize.Concurrency == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be greater than 0", h.logger) + return + } + if payload.ConcurrencyAndBufferSize.BufferSize == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Buffer size must be greater than 0", h.logger) + return + } + + if payload.ConcurrencyAndBufferSize.Concurrency > payload.ConcurrencyAndBufferSize.BufferSize { + SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be less than or equal to buffer size", h.logger) + return + } + } + + // Check if provider already exists + if _, err := h.store.GetProviderConfigRedacted(payload.Provider); err == nil { + SendError(ctx, fasthttp.StatusConflict, fmt.Sprintf("Provider %s already exists", payload.Provider), h.logger) + return + } + + // Construct ProviderConfig from individual fields + config := configstore.ProviderConfig{ + Keys: payload.Keys, + NetworkConfig: payload.NetworkConfig, + ProxyConfig: payload.ProxyConfig, + ConcurrencyAndBufferSize: payload.ConcurrencyAndBufferSize, + SendBackRawResponse: payload.SendBackRawResponse != nil && *payload.SendBackRawResponse, + CustomProviderConfig: payload.CustomProviderConfig, + } + + // Add provider to store (env vars will be processed by store) + if err := h.store.AddProvider(payload.Provider, config); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to add provider %s: %v", payload.Provider, err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add provider: %v", err), h.logger) + return + } + + h.logger.Info(fmt.Sprintf("Provider %s added successfully", payload.Provider)) + + // Get redacted config for response + redactedConfig, err := h.store.GetProviderConfigRedacted(payload.Provider) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to get redacted config for provider %s: %v", payload.Provider, err)) + // Fall back to the raw config (no keys) + response := h.getProviderResponseFromConfig(payload.Provider, configstore.ProviderConfig{ + NetworkConfig: config.NetworkConfig, + ConcurrencyAndBufferSize: config.ConcurrencyAndBufferSize, + ProxyConfig: config.ProxyConfig, + SendBackRawResponse: config.SendBackRawResponse, + CustomProviderConfig: config.CustomProviderConfig, + }) + SendJSON(ctx, response, h.logger) + return + } + + response := h.getProviderResponseFromConfig(payload.Provider, *redactedConfig) + + SendJSON(ctx, response, h.logger) +} + +// updateProvider handles PUT /api/providers/{provider} - Update provider config +// NOTE: This endpoint expects ALL fields to be provided in the request body, +// including both edited and non-edited fields. Partial updates are not supported. +// The frontend should send the complete provider configuration. +// This flow upserts the config +func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err), h.logger) + return + } + + var payload = struct { + Keys []schemas.Key `json:"keys"` // API keys for the provider + NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration + }{} + + if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err), h.logger) + return + } + + // Get the raw config to access actual values for merging with redacted request values + oldConfigRaw, err := h.store.GetProviderConfigRaw(provider) + if err != nil { + SendError(ctx, fasthttp.StatusNotFound, err.Error(), h.logger) + return + } + + if oldConfigRaw == nil { + oldConfigRaw = &configstore.ProviderConfig{} + } + + oldConfigRedacted, err := h.store.GetProviderConfigRedacted(provider) + if err != nil { + SendError(ctx, fasthttp.StatusNotFound, err.Error(), h.logger) + return + } + + if oldConfigRedacted == nil { + oldConfigRedacted = &configstore.ProviderConfig{} + } + + // Construct ProviderConfig from individual fields + config := configstore.ProviderConfig{ + Keys: oldConfigRaw.Keys, + NetworkConfig: oldConfigRaw.NetworkConfig, + ConcurrencyAndBufferSize: oldConfigRaw.ConcurrencyAndBufferSize, + ProxyConfig: oldConfigRaw.ProxyConfig, + CustomProviderConfig: oldConfigRaw.CustomProviderConfig, + } + + // Environment variable cleanup is now handled automatically by mergeKeys function + + var keysToAdd []schemas.Key + var keysToUpdate []schemas.Key + + for _, key := range payload.Keys { + if !slices.ContainsFunc(oldConfigRaw.Keys, func(k schemas.Key) bool { + return k.ID == key.ID + }) { + keysToAdd = append(keysToAdd, key) + } else { + keysToUpdate = append(keysToUpdate, key) + } + } + + var keysToDelete []schemas.Key + for _, key := range oldConfigRaw.Keys { + if !slices.ContainsFunc(payload.Keys, func(k schemas.Key) bool { + return k.ID == key.ID + }) { + keysToDelete = append(keysToDelete, key) + } + } + + keys, err := h.mergeKeys(provider, oldConfigRaw.Keys, oldConfigRedacted.Keys, keysToAdd, keysToDelete, keysToUpdate) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid keys: %v", err), h.logger) + return + } + config.Keys = keys + + if payload.ConcurrencyAndBufferSize.Concurrency == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be greater than 0", h.logger) + return + } + if payload.ConcurrencyAndBufferSize.BufferSize == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Buffer size must be greater than 0", h.logger) + return + } + + if payload.ConcurrencyAndBufferSize.Concurrency > payload.ConcurrencyAndBufferSize.BufferSize { + SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be less than or equal to buffer size", h.logger) + return + } + + // Build a prospective config with the requested CustomProviderConfig (including nil) + prospective := config + prospective.CustomProviderConfig = payload.CustomProviderConfig + if err := lib.ValidateCustomProviderUpdate(prospective, *oldConfigRaw, provider); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid custom provider config: %v", err), h.logger) + return + } + + config.ConcurrencyAndBufferSize = &payload.ConcurrencyAndBufferSize + config.NetworkConfig = &payload.NetworkConfig + config.ProxyConfig = payload.ProxyConfig + config.CustomProviderConfig = payload.CustomProviderConfig + if payload.SendBackRawResponse != nil { + config.SendBackRawResponse = *payload.SendBackRawResponse + } + + // Update provider config in store (env vars will be processed by store) + if err := h.store.UpdateProviderConfig(provider, config); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to update provider %s: %v", provider, err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider: %v", err), h.logger) + return + } + + oldConcurrencyAndBufferSize := &schemas.DefaultConcurrencyAndBufferSize + if oldConfigRaw.ConcurrencyAndBufferSize != nil { + oldConcurrencyAndBufferSize = oldConfigRaw.ConcurrencyAndBufferSize + } + + if config.ConcurrencyAndBufferSize.Concurrency != oldConcurrencyAndBufferSize.Concurrency || + config.ConcurrencyAndBufferSize.BufferSize != oldConcurrencyAndBufferSize.BufferSize { + // Update concurrency and queue configuration in Bifrost + if err := h.client.UpdateProviderConcurrency(provider); err != nil { + // Note: Store update succeeded, continue but log the concurrency update failure + h.logger.Warn(fmt.Sprintf("Failed to update concurrency for provider %s: %v", provider, err)) + } + } + + // Get redacted config for response + redactedConfig, err := h.store.GetProviderConfigRedacted(provider) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to get redacted config for provider %s: %v", provider, err)) + // Fall back to sanitized config (no keys) + response := h.getProviderResponseFromConfig(provider, configstore.ProviderConfig{ + NetworkConfig: config.NetworkConfig, + ConcurrencyAndBufferSize: config.ConcurrencyAndBufferSize, + ProxyConfig: config.ProxyConfig, + SendBackRawResponse: config.SendBackRawResponse, + CustomProviderConfig: config.CustomProviderConfig, + }) + SendJSON(ctx, response, h.logger) + return + } + + response := h.getProviderResponseFromConfig(provider, *redactedConfig) + + SendJSON(ctx, response, h.logger) +} + +// deleteProvider handles DELETE /api/providers/{provider} - Remove provider +func (h *ProviderHandler) deleteProvider(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err), h.logger) + return + } + + // Check if provider exists + if _, err := h.store.GetProviderConfigRedacted(provider); err != nil { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err), h.logger) + return + } + + // Remove provider from store + if err := h.store.RemoveProvider(provider); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to remove provider %s: %v", provider, err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to remove provider: %v", err), h.logger) + return + } + + h.logger.Info(fmt.Sprintf("Provider %s removed successfully", provider)) + + response := ProviderResponse{ + Name: provider, + } + + SendJSON(ctx, response, h.logger) +} + +// listKeys handles GET /api/keys - List all keys +func (h *ProviderHandler) listKeys(ctx *fasthttp.RequestCtx) { + keys, err := h.store.GetAllKeys() + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get keys: %v", err), h.logger) + return + } + + SendJSON(ctx, keys, h.logger) +} + +// mergeKeys merges new keys with old, preserving values that are redacted in the new config +func (h *ProviderHandler) mergeKeys(provider schemas.ModelProvider, oldRawKeys []schemas.Key, oldRedactedKeys []schemas.Key, keysToAdd []schemas.Key, keysToDelete []schemas.Key, keysToUpdate []schemas.Key) ([]schemas.Key, error) { + // Clean up environment variables for deleted keys only + // Updated keys will be cleaned up after merge to avoid premature cleanup + h.store.CleanupEnvKeysForKeys(provider, keysToDelete) + // Create a map of indices to delete + toDelete := make(map[int]bool) + for _, key := range keysToDelete { + for i, oldKey := range oldRawKeys { + if oldKey.ID == key.ID { + toDelete[i] = true + break + } + } + } + + // Create a map of updates by ID for quick lookup + updates := make(map[string]schemas.Key) + for _, key := range keysToUpdate { + updates[key.ID] = key + } + + // Map old redacted keys by ID for reliable lookup + redactedByID := make(map[string]schemas.Key) + for _, rk := range oldRedactedKeys { + redactedByID[rk.ID] = rk + } + + // Process existing keys (handle updates and deletions) + var resultKeys []schemas.Key + for i, oldRawKey := range oldRawKeys { + // Skip if this key should be deleted + if toDelete[i] { + continue + } + + // Check if this key should be updated + if updateKey, exists := updates[oldRawKey.ID]; exists { + oldRedactedKey, ok := redactedByID[oldRawKey.ID] + if !ok { + oldRedactedKey = schemas.Key{} + } + mergedKey := updateKey + + // Handle redacted values - preserve old value if new value is redacted/env var AND it's the same as old redacted value + if lib.IsRedacted(updateKey.Value) && + strings.EqualFold(updateKey.Value, oldRedactedKey.Value) { + mergedKey.Value = oldRawKey.Value + } + + // Handle Azure config redacted values + if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil { + if lib.IsRedacted(updateKey.AzureKeyConfig.Endpoint) && + strings.EqualFold(updateKey.AzureKeyConfig.Endpoint, oldRedactedKey.AzureKeyConfig.Endpoint) { + mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint + } + if updateKey.AzureKeyConfig.APIVersion != nil && + oldRedactedKey.AzureKeyConfig.APIVersion != nil && + oldRawKey.AzureKeyConfig != nil { + if lib.IsRedacted(*updateKey.AzureKeyConfig.APIVersion) && + strings.EqualFold(*updateKey.AzureKeyConfig.APIVersion, *oldRedactedKey.AzureKeyConfig.APIVersion) { + mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion + } + } + } + + // Handle Vertex config redacted values + if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil { + if lib.IsRedacted(updateKey.VertexKeyConfig.ProjectID) && + strings.EqualFold(updateKey.VertexKeyConfig.ProjectID, oldRedactedKey.VertexKeyConfig.ProjectID) { + mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID + } + if lib.IsRedacted(updateKey.VertexKeyConfig.Region) && + strings.EqualFold(updateKey.VertexKeyConfig.Region, oldRedactedKey.VertexKeyConfig.Region) { + mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region + } + if lib.IsRedacted(updateKey.VertexKeyConfig.AuthCredentials) && + strings.EqualFold(updateKey.VertexKeyConfig.AuthCredentials, oldRedactedKey.VertexKeyConfig.AuthCredentials) { + mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials + } + } + + // Handle Bedrock config redacted values + if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil { + if lib.IsRedacted(updateKey.BedrockKeyConfig.AccessKey) && + strings.EqualFold(updateKey.BedrockKeyConfig.AccessKey, oldRedactedKey.BedrockKeyConfig.AccessKey) { + mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey + } + if lib.IsRedacted(updateKey.BedrockKeyConfig.SecretKey) && + strings.EqualFold(updateKey.BedrockKeyConfig.SecretKey, oldRedactedKey.BedrockKeyConfig.SecretKey) { + mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey + } + if updateKey.BedrockKeyConfig.SessionToken != nil && + oldRedactedKey.BedrockKeyConfig.SessionToken != nil && + oldRawKey.BedrockKeyConfig != nil { + if lib.IsRedacted(*updateKey.BedrockKeyConfig.SessionToken) && + strings.EqualFold(*updateKey.BedrockKeyConfig.SessionToken, *oldRedactedKey.BedrockKeyConfig.SessionToken) { + mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken + } + } + if updateKey.BedrockKeyConfig.Region != nil { + if lib.IsRedacted(*updateKey.BedrockKeyConfig.Region) && + (!strings.HasPrefix(*updateKey.BedrockKeyConfig.Region, "env.") || + (oldRedactedKey.BedrockKeyConfig.Region != nil && + !strings.EqualFold(*updateKey.BedrockKeyConfig.Region, *oldRedactedKey.BedrockKeyConfig.Region))) { + mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region + } + } + if updateKey.BedrockKeyConfig.ARN != nil { + if lib.IsRedacted(*updateKey.BedrockKeyConfig.ARN) && + (!strings.HasPrefix(*updateKey.BedrockKeyConfig.ARN, "env.") || + (oldRedactedKey.BedrockKeyConfig.ARN != nil && + !strings.EqualFold(*updateKey.BedrockKeyConfig.ARN, *oldRedactedKey.BedrockKeyConfig.ARN))) { + mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN + } + } + } + + resultKeys = append(resultKeys, mergedKey) + } else { + // Keep unchanged key + resultKeys = append(resultKeys, oldRawKey) + } + } + + // Add new keys + resultKeys = append(resultKeys, keysToAdd...) + + // Clean up environment variables for updated keys after merge + // This allows us to compare the final merged values with the original values + h.store.CleanupEnvKeysForUpdatedKeys(provider, keysToUpdate, oldRawKeys, resultKeys) + + return resultKeys, nil +} + +func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelProvider, config configstore.ProviderConfig) ProviderResponse { + if config.NetworkConfig == nil { + config.NetworkConfig = &schemas.DefaultNetworkConfig + } + if config.ConcurrencyAndBufferSize == nil { + config.ConcurrencyAndBufferSize = &schemas.DefaultConcurrencyAndBufferSize + } + + return ProviderResponse{ + Name: provider, + Keys: config.Keys, + NetworkConfig: *config.NetworkConfig, + ConcurrencyAndBufferSize: *config.ConcurrencyAndBufferSize, + ProxyConfig: config.ProxyConfig, + SendBackRawResponse: config.SendBackRawResponse, + CustomProviderConfig: config.CustomProviderConfig, + } +} + +func getProviderFromCtx(ctx *fasthttp.RequestCtx) (schemas.ModelProvider, error) { + providerValue := ctx.UserValue("provider") + if providerValue == nil { + return "", fmt.Errorf("missing provider parameter") + } + providerStr, ok := providerValue.(string) + if !ok { + return "", fmt.Errorf("invalid provider parameter type") + } + + decoded, err := url.PathUnescape(providerStr) + if err != nil { + return "", fmt.Errorf("invalid provider parameter encoding: %v", err) + } + + return schemas.ModelProvider(decoded), nil +} diff --git a/transports/bifrost-http/handlers/utils.go b/transports/bifrost-http/handlers/utils.go new file mode 100644 index 000000000..80ed171c9 --- /dev/null +++ b/transports/bifrost-http/handlers/utils.go @@ -0,0 +1,112 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains common utility functions used across all handlers. +package handlers + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// SendJSON sends a JSON response with 200 OK status +func SendJSON(ctx *fasthttp.RequestCtx, data interface{}, logger schemas.Logger) { + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + + if err := json.NewEncoder(ctx).Encode(data); err != nil { + logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err), logger) + } +} + +// SendError sends a BifrostError response +func SendError(ctx *fasthttp.RequestCtx, statusCode int, message string, logger schemas.Logger) { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: schemas.ErrorField{ + Message: message, + }, + } + SendBifrostError(ctx, bifrostErr, logger) +} + +// SendBifrostError sends a BifrostError response +func SendBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError, logger schemas.Logger) { + if bifrostErr.StatusCode != nil { + ctx.SetStatusCode(*bifrostErr.StatusCode) + } else if !bifrostErr.IsBifrostError { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + } else { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + } + + ctx.SetContentType("application/json") + if encodeErr := json.NewEncoder(ctx).Encode(bifrostErr); encodeErr != nil { + logger.Warn(fmt.Sprintf("Failed to encode error response: %v", encodeErr)) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("Failed to encode error response: %v", encodeErr)) + } +} + +// SendSSEError sends an error in Server-Sent Events format +func SendSSEError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError, logger schemas.Logger) { + errorJSON, err := json.Marshal(map[string]interface{}{ + "error": bifrostErr, + }) + if err != nil { + logger.Error("failed to marshal error for SSE: %v", err) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + return + } + + if _, err := fmt.Fprintf(ctx, "data: %s\n\n", errorJSON); err != nil { + logger.Warn(fmt.Sprintf("Failed to write SSE error: %v", err)) + } +} + +// IsOriginAllowed checks if the given origin is allowed based on localhost rules and configured allowed origins. +// Localhost origins are always allowed. Additional origins can be configured in allowedOrigins. +func IsOriginAllowed(origin string, allowedOrigins []string) bool { + // Always allow localhost origins + if isLocalhostOrigin(origin) { + return true + } + + // Check configured allowed origins + return slices.Contains(allowedOrigins, origin) +} + +// isLocalhostOrigin checks if the given origin is a localhost origin +func isLocalhostOrigin(origin string) bool { + return strings.HasPrefix(origin, "http://localhost:") || + strings.HasPrefix(origin, "https://localhost:") || + strings.HasPrefix(origin, "http://127.0.0.1:") || + strings.HasPrefix(origin, "http://0.0.0.0:") || + strings.HasPrefix(origin, "https://127.0.0.1:") +} + +// ParseModel parses a model string in the format "provider/model" or "provider/nested/model" +// Returns the provider and full model name after the first slash +func ParseModel(model string) (string, string, error) { + model = strings.TrimSpace(model) + if model == "" { + return "", "", fmt.Errorf("model cannot be empty") + } + + parts := strings.SplitN(model, "/", 2) + if len(parts) < 2 { + return "", "", fmt.Errorf("model must be in the format 'provider/model'") + } + + provider := strings.TrimSpace(parts[0]) + name := strings.TrimSpace(parts[1]) + if provider == "" || name == "" { + return "", "", fmt.Errorf("model must be in the format 'provider/model' with non-empty provider and model") + } + return provider, name, nil +} diff --git a/transports/bifrost-http/handlers/websocket.go b/transports/bifrost-http/handlers/websocket.go new file mode 100644 index 000000000..4e640574f --- /dev/null +++ b/transports/bifrost-http/handlers/websocket.go @@ -0,0 +1,255 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains WebSocket handlers for real-time log streaming. +package handlers + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + "github.com/fasthttp/websocket" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/plugins/logging" + "github.com/valyala/fasthttp" +) + +// WebSocketClient represents a connected WebSocket client with its own mutex +type WebSocketClient struct { + conn *websocket.Conn + mu sync.Mutex // Per-connection mutex for thread-safe writes +} + +// WebSocketHandler manages WebSocket connections for real-time updates +type WebSocketHandler struct { + logManager logging.LogManager + logger schemas.Logger + allowedOrigins []string + clients map[*websocket.Conn]*WebSocketClient + mu sync.RWMutex + stopChan chan struct{} // Channel to signal heartbeat goroutine to stop + done chan struct{} // Channel to signal when heartbeat goroutine has stopped +} + +// NewWebSocketHandler creates a new WebSocket handler instance +func NewWebSocketHandler(logManager logging.LogManager, logger schemas.Logger, allowedOrigins []string) *WebSocketHandler { + return &WebSocketHandler{ + logManager: logManager, + logger: logger, + allowedOrigins: allowedOrigins, + clients: make(map[*websocket.Conn]*WebSocketClient), + stopChan: make(chan struct{}), + done: make(chan struct{}), + } +} + +// RegisterRoutes registers all WebSocket-related routes +func (h *WebSocketHandler) RegisterRoutes(r *router.Router) { + r.GET("/ws/logs", h.connectLogStream) +} + +// getUpgrader returns a WebSocket upgrader configured with the current allowed origins +func (h *WebSocketHandler) getUpgrader() websocket.FastHTTPUpgrader { + return websocket.FastHTTPUpgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { + origin := string(ctx.Request.Header.Peek("Origin")) + if origin == "" { + // If no Origin header, check the Host header for direct connections + host := string(ctx.Request.Header.Peek("Host")) + return isLocalhost(host) + } + // Check if origin is allowed (localhost always allowed + configured origins) + return IsOriginAllowed(origin, h.allowedOrigins) + }, + } +} + +// isLocalhost checks if the given host is localhost +func isLocalhost(host string) bool { + // Remove port if present + if idx := strings.LastIndex(host, ":"); idx != -1 { + host = host[:idx] + } + + // Check for localhost variations + return host == "localhost" || + host == "127.0.0.1" || + host == "::1" || + host == "" +} + +// connectLogStream handles WebSocket connections for real-time log streaming +func (h *WebSocketHandler) connectLogStream(ctx *fasthttp.RequestCtx) { + upgrader := h.getUpgrader() + err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { + // Read safety & liveness + ws.SetReadLimit(50 << 20) // 50 MiB + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + // Create a new client with its own mutex + client := &WebSocketClient{ + conn: ws, + } + + // Register new client + h.mu.Lock() + h.clients[ws] = client + h.mu.Unlock() + + // Clean up on disconnect + defer func() { + h.mu.Lock() + delete(h.clients, ws) + h.mu.Unlock() + ws.Close() + }() + + // Keep connection alive and handle client messages + // This loop continuously reads and discards incoming WebSocket messages to: + // 1. Keep the connection alive by processing client pings and control frames + // 2. Detect when the client disconnects by watching for close frames or errors + // 3. Maintain proper WebSocket protocol handling without accumulating messages + for { + _, _, err := ws.ReadMessage() + if err != nil { + // Only log unexpected close errors + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + websocket.CloseNoStatusReceived) { + h.logger.Error("websocket read error: %v", err) + } + break + } + } + }) + + if err != nil { + h.logger.Error("websocket upgrade error: %v", err) + return + } +} + +// sendMessageSafely sends a message to a client with proper locking and error handling +func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageType int, data []byte) error { + client.mu.Lock() + defer client.mu.Unlock() + + // Set a write deadline to prevent hanging connections + client.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + defer client.conn.SetWriteDeadline(time.Time{}) // Clear the deadline + + err := client.conn.WriteMessage(messageType, data) + if err != nil { + // Remove the client from the map if write fails + go func() { + h.mu.Lock() + delete(h.clients, client.conn) + h.mu.Unlock() + client.conn.Close() + }() + } + return err +} + +// BroadcastLogUpdate sends a log update to all connected WebSocket clients +func (h *WebSocketHandler) BroadcastLogUpdate(logEntry *logstore.Log) { + // Add panic recovery to prevent server crashes + defer func() { + if r := recover(); r != nil { + h.logger.Error("panic in BroadcastLogUpdate: %v", r) + } + }() + + // Determine operation type based on log status and timestamp + operationType := "update" + if logEntry.Status == "processing" && logEntry.CreatedAt.Equal(logEntry.Timestamp) { + operationType = "create" + } + + message := struct { + Type string `json:"type"` + Operation string `json:"operation"` // "create" or "update" + Payload *logstore.Log `json:"payload"` + }{ + Type: "log", + Operation: operationType, + Payload: logEntry, + } + + data, err := json.Marshal(message) + if err != nil { + h.logger.Error("failed to marshal log entry: %v", err) + return + } + + // Get a snapshot of clients to avoid holding the lock during writes + h.mu.RLock() + clients := make([]*WebSocketClient, 0, len(h.clients)) + for _, client := range h.clients { + clients = append(clients, client) + } + h.mu.RUnlock() + + // Send message to each client safely + for _, client := range clients { + if err := h.sendMessageSafely(client, websocket.TextMessage, data); err != nil { + h.logger.Error("failed to send message to client: %v", err) + } + } +} + +// StartHeartbeat starts sending periodic heartbeat messages to keep connections alive +func (h *WebSocketHandler) StartHeartbeat() { + ticker := time.NewTicker(30 * time.Second) + go func() { + defer func() { + ticker.Stop() + close(h.done) + }() + + for { + select { + case <-ticker.C: + // Get a snapshot of clients to avoid holding the lock during writes + h.mu.RLock() + clients := make([]*WebSocketClient, 0, len(h.clients)) + for _, client := range h.clients { + clients = append(clients, client) + } + h.mu.RUnlock() + + // Send heartbeat to each client safely + for _, client := range clients { + if err := h.sendMessageSafely(client, websocket.PingMessage, nil); err != nil { + h.logger.Error("failed to send heartbeat: %v", err) + } + } + case <-h.stopChan: + return + } + } + }() +} + +// Stop gracefully shuts down the WebSocket handler +func (h *WebSocketHandler) Stop() { + close(h.stopChan) // Signal heartbeat goroutine to stop + <-h.done // Wait for heartbeat goroutine to finish + + // Close all client connections + h.mu.Lock() + for _, client := range h.clients { + client.conn.Close() + } + h.clients = make(map[*websocket.Conn]*WebSocketClient) + h.mu.Unlock() +} diff --git a/transports/bifrost-http/integrations/anthropic/router.go b/transports/bifrost-http/integrations/anthropic/router.go new file mode 100644 index 000000000..8c9e8a36b --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/router.go @@ -0,0 +1,55 @@ +package anthropic + +import ( + "errors" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +// AnthropicRouter handles Anthropic-compatible API endpoints +type AnthropicRouter struct { + *integrations.GenericRouter +} + +// CreateAnthropicRouteConfigs creates route configurations for Anthropic endpoints. +func CreateAnthropicRouteConfigs(pathPrefix string) []integrations.RouteConfig { + return []integrations.RouteConfig{ + { + Path: pathPrefix + "/v1/messages", + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &AnthropicMessageRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if anthropicReq, ok := req.(*AnthropicMessageRequest); ok { + return anthropicReq.ConvertToBifrostRequest(), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveAnthropicFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveAnthropicErrorFromBifrostError(err) + }, + StreamConfig: &integrations.StreamConfig{ + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveAnthropicStreamFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveAnthropicStreamFromBifrostError(err) + }, + }, + }, + } +} + +// NewAnthropicRouter creates a new AnthropicRouter with the given bifrost client. +func NewAnthropicRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *AnthropicRouter { + return &AnthropicRouter{ + GenericRouter: integrations.NewGenericRouter(client, handlerStore, CreateAnthropicRouteConfigs("/anthropic")), + } +} diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go new file mode 100644 index 000000000..c907d0775 --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -0,0 +1,698 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" +) + +var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)) + +// AnthropicContentBlock represents content in Anthropic message format +type AnthropicContentBlock struct { + Type string `json:"type"` // "text", "image", "tool_use", "tool_result" + Text *string `json:"text,omitempty"` // For text content + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input interface{} `json:"input,omitempty"` // For tool_use content + Content AnthropicContent `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content +} + +// AnthropicImageSource represents image source in Anthropic format +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", etc. + Data *string `json:"data,omitempty"` // Base64-encoded image data + URL *string `json:"url,omitempty"` // URL of the image +} + +// AnthropicMessage represents a message in Anthropic format +type AnthropicMessage struct { + Role string `json:"role"` // "user", "assistant" + Content AnthropicContent `json:"content"` // Array of content blocks +} + +type AnthropicContent struct { + ContentStr *string + ContentBlocks *[]AnthropicContentBlock +} + +// AnthropicTool represents a tool in Anthropic format +type AnthropicTool struct { + Name string `json:"name"` + Type *string `json:"type,omitempty"` + Description string `json:"description"` + InputSchema *struct { + Type string `json:"type"` // "object" + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required"` + } `json:"input_schema,omitempty"` +} + +// AnthropicToolChoice represents tool choice in Anthropic format +type AnthropicToolChoice struct { + Type string `json:"type"` // "auto", "any", "tool" + Name string `json:"name,omitempty"` // For type "tool" +} + +// AnthropicMessageRequest represents an Anthropic messages API request +type AnthropicMessageRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []AnthropicMessage `json:"messages"` + System *AnthropicContent `json:"system,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences *[]string `json:"stop_sequences,omitempty"` + Stream *bool `json:"stream,omitempty"` + Tools *[]AnthropicTool `json:"tools,omitempty"` + ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` +} + +// IsStreamingRequested implements the StreamingRequest interface +func (r *AnthropicMessageRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream +} + +// AnthropicMessageResponse represents an Anthropic messages API response +type AnthropicMessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicUsage represents usage information in Anthropic format +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// AnthropicMessageError represents an Anthropic messages API error response +type AnthropicMessageError struct { + Type string `json:"type"` // always "error" + Error AnthropicMessageErrorStruct `json:"error"` // Error details +} + +// AnthropicMessageErrorStruct represents the error structure of an Anthropic messages API error response +type AnthropicMessageErrorStruct struct { + Type string `json:"type"` // Error type + Message string `json:"message"` // Error message +} + +// AnthropicStreamResponse represents a single chunk in the Anthropic streaming response +// This matches the format expected by Anthropic's streaming API clients +type AnthropicStreamResponse struct { + Type string `json:"type"` + ID *string `json:"id,omitempty"` + Model *string `json:"model,omitempty"` + Index *int `json:"index,omitempty"` + Message *AnthropicStreamMessage `json:"message,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Delta *AnthropicStreamDelta `json:"delta,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicStreamMessage represents the message structure in streaming events +type AnthropicStreamMessage struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicStreamDelta represents the incremental content in a streaming chunk +type AnthropicStreamDelta struct { + Type string `json:"type"` + Text *string `json:"text,omitempty"` + Thinking *string `json:"thinking,omitempty"` + PartialJSON *string `json:"partial_json,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// MarshalJSON implements custom JSON marshalling for MessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc AnthropicContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return json.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return json.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return json.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := json.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []AnthropicContentBlock + if err := json.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") +} + +// ConvertToBifrostRequest converts an Anthropic messages request to Bifrost format +func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.Anthropic, false) + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + } + + messages := []schemas.BifrostMessage{} + + // Add system message if present + if r.System != nil { + if r.System.ContentStr != nil && *r.System.ContentStr != "" { + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: r.System.ContentStr, + }, + }) + } else if r.System.ContentBlocks != nil { + contentBlocks := []schemas.ContentBlock{} + for _, block := range *r.System.ContentBlocks { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: block.Text, + }) + } + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + }) + } + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + + if msg.Content.ContentStr != nil { + bifrostMsg.Content = schemas.MessageContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + // Handle different content types + var toolCalls []schemas.ToolCall + var contentBlocks []schemas.ContentBlock + + for _, content := range *msg.Content.ContentBlocks { + switch content.Type { + case "text": + if content.Text != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: content.Text, + }) + } + case "image": + if content.Source != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: func() string { + if content.Source.Data != nil { + mime := "image/png" + if content.Source.MediaType != nil && *content.Source.MediaType != "" { + mime = *content.Source.MediaType + } + return "data:" + mime + ";base64," + *content.Source.Data + } + if content.Source.URL != nil { + return *content.Source.URL + } + return "" + }(), + }, + }) + } + case "tool_use": + if content.ID != nil && content.Name != nil { + tc := schemas.ToolCall{ + Type: fnTypePtr, + ID: content.ID, + Function: schemas.FunctionCall{ + Name: content.Name, + Arguments: jsonifyInput(content.Input), + }, + } + toolCalls = append(toolCalls, tc) + } + case "tool_result": + if content.ToolUseID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: content.ToolUseID, + } + if content.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: content.Content.ContentStr, + }) + } else if content.Content.ContentBlocks != nil { + for _, block := range *content.Content.ContentBlocks { + if block.Text != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: block.Text, + }) + } else if block.Source != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: func() string { + if block.Source.Data != nil { + mime := "image/png" + if block.Source.MediaType != nil && *block.Source.MediaType != "" { + mime = *block.Source.MediaType + } + return "data:" + mime + ";base64," + *block.Source.Data + } + if block.Source.URL != nil { + return *block.Source.URL + } + return "" + }()}, + }) + } + } + } + bifrostMsg.Role = schemas.ModelChatMessageRoleTool + } + } + } + + // Concatenate all text contents + if len(contentBlocks) > 0 { + bifrostMsg.Content = schemas.MessageContent{ + ContentBlocks: &contentBlocks, + } + } + + if len(toolCalls) > 0 && msg.Role == string(schemas.ModelChatMessageRoleAssistant) { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + } + messages = append(messages, bifrostMsg) + } + + bifrostReq.Input.ChatCompletionInput = &messages + + // Convert parameters + if r.MaxTokens > 0 || r.Temperature != nil || r.TopP != nil || r.TopK != nil || r.StopSequences != nil { + params := &schemas.ModelParameters{} + + if r.MaxTokens > 0 { + params.MaxTokens = &r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.TopK != nil { + params.TopK = r.TopK + } + if r.StopSequences != nil { + params.StopSequences = r.StopSequences + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + // Convert input_schema to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.InputSchema != nil { + params.Type = tool.InputSchema.Type + params.Required = tool.InputSchema.Required + params.Properties = tool.InputSchema.Properties + } + + tools = append(tools, schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: params, + }, + }) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + // Convert tool choice + if r.ToolChoice != nil { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + toolChoice := &schemas.ToolChoice{ + ToolChoiceStruct: &schemas.ToolChoiceStruct{ + Type: func() schemas.ToolChoiceType { + if r.ToolChoice.Type == "tool" { + return schemas.ToolChoiceTypeFunction + } + return schemas.ToolChoiceType(r.ToolChoice.Type) + }(), + }, + } + if r.ToolChoice.Type == "tool" && r.ToolChoice.Name != "" { + toolChoice.ToolChoiceStruct.Function = schemas.ToolChoiceFunction{ + Name: r.ToolChoice.Name, + } + } + bifrostReq.Params.ToolChoice = toolChoice + } + + // Apply parameter validation + if bifrostReq.Params != nil { + bifrostReq.Params = integrations.ValidateAndFilterParamsForProvider(provider, bifrostReq.Params) + } + + return bifrostReq +} + +// Helper function to convert interface{} to JSON string +func jsonifyInput(input interface{}) string { + if input == nil { + return "{}" + } + jsonBytes, err := json.Marshal(input) + if err != nil { + return "{}" + } + return string(jsonBytes) +} + +// DeriveAnthropicFromBifrostResponse converts a Bifrost response to Anthropic format +func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *AnthropicMessageResponse { + if bifrostResp == nil { + return nil + } + + anthropicResp := &AnthropicMessageResponse{ + ID: bifrostResp.ID, + Type: "message", + Role: string(schemas.ModelChatMessageRoleAssistant), + Model: bifrostResp.Model, + } + + // Convert usage information + if bifrostResp.Usage != nil { + anthropicResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + } + } + + // Convert choices to content + var content []AnthropicContentBlock + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] // Anthropic typically returns one choice + + if choice.FinishReason != nil { + mappedReason := integrations.MapFinishReasonToProvider(*choice.FinishReason, schemas.Anthropic) + anthropicResp.StopReason = &mappedReason + } + if choice.StopString != nil { + anthropicResp.StopSequence = choice.StopString + } + + // Add thinking content if present + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { + content = append(content, AnthropicContentBlock{ + Type: "thinking", + Text: choice.Message.AssistantMessage.Thought, + }) + } + + // Add text content + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + content = append(content, AnthropicContentBlock{ + Type: "text", + Text: choice.Message.Content.ContentStr, + }) + } else if choice.Message.Content.ContentBlocks != nil { + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + content = append(content, AnthropicContentBlock{ + Type: "text", + Text: block.Text, + }) + } + } + } + + // Add tool calls as tool_use content + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + // Parse arguments JSON string back to map + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + } else { + input = map[string]interface{}{} + } + + content = append(content, AnthropicContentBlock{ + Type: "tool_use", + ID: toolCall.ID, + Name: toolCall.Function.Name, + Input: input, + }) + } + } + } + + if content == nil { + content = []AnthropicContentBlock{} + } + + anthropicResp.Content = content + return anthropicResp +} + +// DeriveAnthropicStreamFromBifrostResponse converts a Bifrost streaming response to Anthropic SSE string format +func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) string { + if bifrostResp == nil { + return "" + } + + streamResp := &AnthropicStreamResponse{} + + // Handle different streaming event types based on the response content + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] // Anthropic typically returns one choice + + // Handle streaming responses + if choice.BifrostStreamResponseChoice != nil { + delta := choice.BifrostStreamResponseChoice.Delta + + // Handle text content deltas + if delta.Content != nil { + streamResp.Type = "content_block_delta" + streamResp.Index = &choice.Index + streamResp.Delta = &AnthropicStreamDelta{ + Type: "text_delta", + Text: delta.Content, + } + } else if delta.Thought != nil { + // Handle thinking content deltas + streamResp.Type = "content_block_delta" + streamResp.Index = &choice.Index + streamResp.Delta = &AnthropicStreamDelta{ + Type: "thinking_delta", + Thinking: delta.Thought, + } + } else if len(delta.ToolCalls) > 0 { + // Handle tool call deltas + toolCall := delta.ToolCalls[0] // Take first tool call + + if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + // Tool use start event + streamResp.Type = "content_block_start" + streamResp.Index = &choice.Index + streamResp.ContentBlock = &AnthropicContentBlock{ + Type: "tool_use", + ID: toolCall.ID, + Name: toolCall.Function.Name, + } + } else if toolCall.Function.Arguments != "" { + // Tool input delta + streamResp.Type = "content_block_delta" + streamResp.Index = &choice.Index + streamResp.Delta = &AnthropicStreamDelta{ + Type: "input_json_delta", + PartialJSON: &toolCall.Function.Arguments, + } + } + } else if choice.FinishReason != nil && *choice.FinishReason != "" { + // Handle finish reason - map back to Anthropic format + stopReason := integrations.MapFinishReasonToProvider(*choice.FinishReason, schemas.Anthropic) + streamResp.Type = "message_delta" + streamResp.Delta = &AnthropicStreamDelta{ + Type: "message_delta", + StopReason: &stopReason, + } + } + + } else if choice.BifrostNonStreamResponseChoice != nil { + // Handle non-streaming response converted to streaming format + streamResp.Type = "message_start" + + // Create message start event + streamMessage := &AnthropicStreamMessage{ + ID: bifrostResp.ID, + Type: "message", + Role: string(choice.BifrostNonStreamResponseChoice.Message.Role), + Model: bifrostResp.Model, + } + + // Convert content + var content []AnthropicContentBlock + if choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr != nil { + content = append(content, AnthropicContentBlock{ + Type: "text", + Text: choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr, + }) + } + + streamMessage.Content = content + streamResp.Message = streamMessage + } + } + + // Handle usage information + if bifrostResp.Usage != nil { + if streamResp.Type == "" { + streamResp.Type = "message_delta" + } + streamResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + } + } + + // Set common fields + if bifrostResp.ID != "" { + streamResp.ID = &bifrostResp.ID + } + if bifrostResp.Model != "" { + streamResp.Model = &bifrostResp.Model + } + + // Default to empty content_block_delta if no specific type was set + if streamResp.Type == "" { + streamResp.Type = "content_block_delta" + streamResp.Index = bifrost.Ptr(0) + streamResp.Delta = &AnthropicStreamDelta{ + Type: "text_delta", + Text: bifrost.Ptr(""), + } + } + + // Marshal to JSON and format as SSE + jsonData, err := json.Marshal(streamResp) + if err != nil { + return "" + } + + // Format as Anthropic SSE + return fmt.Sprintf("event: %s\ndata: %s\n\n", streamResp.Type, jsonData) +} + +// DeriveAnthropicErrorFromBifrostError derives a AnthropicMessageError from a BifrostError +func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *AnthropicMessageError { + if bifrostErr == nil { + return nil + } + + // Provide blank strings for nil pointer fields + errorType := "" + if bifrostErr.Type != nil { + errorType = *bifrostErr.Type + } + + // Handle nested error fields with nil checks + errorStruct := AnthropicMessageErrorStruct{ + Type: "", + Message: bifrostErr.Error.Message, + } + + if bifrostErr.Error.Type != nil { + errorStruct.Type = *bifrostErr.Error.Type + } + + return &AnthropicMessageError{ + Type: errorType, + Error: errorStruct, + } +} + +// DeriveAnthropicStreamFromBifrostError derives an Anthropic streaming error from a BifrostError in SSE format +func DeriveAnthropicStreamFromBifrostError(bifrostErr *schemas.BifrostError) string { + errorResp := DeriveAnthropicErrorFromBifrostError(bifrostErr) + if errorResp == nil { + return "" + } + + // Marshal to JSON + jsonData, err := json.Marshal(errorResp) + if err != nil { + return "" + } + + // Format as Anthropic SSE error event + return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) +} diff --git a/transports/bifrost-http/integrations/genai/router.go b/transports/bifrost-http/integrations/genai/router.go new file mode 100644 index 000000000..5192cea80 --- /dev/null +++ b/transports/bifrost-http/integrations/genai/router.go @@ -0,0 +1,117 @@ +package genai + +import ( + "errors" + "fmt" + "strings" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// GenAIRouter holds route registrations for genai endpoints. +type GenAIRouter struct { + *integrations.GenericRouter +} + +// CreateGenAIRouteConfigs creates a route configurations for GenAI endpoints. +func CreateGenAIRouteConfigs(pathPrefix string) []integrations.RouteConfig { + var routes []integrations.RouteConfig + + // Chat completions endpoint + routes = append(routes, integrations.RouteConfig{ + Path: pathPrefix + "/v1beta/models/{model:*}", + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &GeminiChatRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if geminiReq, ok := req.(*GeminiChatRequest); ok { + return geminiReq.ConvertToBifrostRequest(), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveGenAIFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveGeminiErrorFromBifrostError(err) + }, + StreamConfig: &integrations.StreamConfig{ + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveGeminiStreamFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveGeminiStreamFromBifrostError(err) + }, + }, + PreCallback: extractAndSetModelFromURL, + }) + + return routes +} + +// NewGenAIRouter creates a new GenAIRouter with the given bifrost client. +func NewGenAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *GenAIRouter { + return &GenAIRouter{ + GenericRouter: integrations.NewGenericRouter(client, handlerStore, CreateGenAIRouteConfigs("/genai")), + } +} + +var embeddingPaths = []string{ + ":embedContent", + ":batchEmbedContents", + ":predict", +} + +// extractAndSetModelFromURL extracts model from URL and sets it in the request +func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, req interface{}) error { + model := ctx.UserValue("model") + if model == nil { + return fmt.Errorf("model parameter is required") + } + + modelStr := model.(string) + + // Check if this is an embedding request + isEmbedding := false + for _, path := range embeddingPaths { + if strings.HasSuffix(modelStr, path) { + isEmbedding = true + break + } + } + + // Check if this is a streaming request + isStreaming := strings.HasSuffix(modelStr, ":streamGenerateContent") + + // Remove Google GenAI API endpoint suffixes if present + for _, sfx := range []string{ + ":streamGenerateContent", + ":generateContent", + ":countTokens", + ":embedContent", + ":batchEmbedContents", + ":predict", + } { + modelStr = strings.TrimSuffix(modelStr, sfx) + } + + // Remove trailing colon if present + if len(modelStr) > 0 && modelStr[len(modelStr)-1] == ':' { + modelStr = modelStr[:len(modelStr)-1] + } + + // Set the model and flags in the request + if geminiReq, ok := req.(*GeminiChatRequest); ok { + geminiReq.Model = modelStr + geminiReq.Stream = isStreaming + geminiReq.IsEmbedding = isEmbedding + return nil + } + + return fmt.Errorf("invalid request type for GenAI") +} diff --git a/transports/bifrost-http/integrations/genai/types.go b/transports/bifrost-http/integrations/genai/types.go new file mode 100644 index 000000000..3b8100eb0 --- /dev/null +++ b/transports/bifrost-http/integrations/genai/types.go @@ -0,0 +1,964 @@ +package genai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + genai_sdk "google.golang.org/genai" +) + +var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)) + +// EmbeddingRequest represents a single embedding request in a batch +type EmbeddingRequest struct { + Content *CustomContent `json:"content,omitempty"` + TaskType *string `json:"taskType,omitempty"` + Title *string `json:"title,omitempty"` + OutputDimensionality *int `json:"outputDimensionality,omitempty"` + Model string `json:"model,omitempty"` +} + +// CustomBlob handles URL-safe base64 decoding for Google GenAI requests +type CustomBlob struct { + Data []byte `json:"data,omitempty"` + MIMEType string `json:"mimeType,omitempty"` +} + +// UnmarshalJSON custom unmarshalling to handle URL-safe base64 encoding +func (b *CustomBlob) UnmarshalJSON(data []byte) error { + // First unmarshal into a temporary struct with string data + var temp struct { + Data string `json:"data,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + b.MIMEType = temp.MIMEType + + if temp.Data != "" { + // Convert URL-safe base64 to standard base64 + standardBase64 := strings.ReplaceAll(strings.ReplaceAll(temp.Data, "_", "/"), "-", "+") + + // Add padding if necessary + switch len(standardBase64) % 4 { + case 2: + standardBase64 += "==" + case 3: + standardBase64 += "=" + } + + decoded, err := base64.StdEncoding.DecodeString(standardBase64) + if err != nil { + return fmt.Errorf("failed to decode base64 data: %v", err) + } + b.Data = decoded + } + + return nil +} + +// CustomPart handles Google GenAI Part with custom Blob unmarshalling +type CustomPart struct { + VideoMetadata *genai_sdk.VideoMetadata `json:"videoMetadata,omitempty"` + Thought bool `json:"thought,omitempty"` + CodeExecutionResult *genai_sdk.CodeExecutionResult `json:"codeExecutionResult,omitempty"` + ExecutableCode *genai_sdk.ExecutableCode `json:"executableCode,omitempty"` + FileData *genai_sdk.FileData `json:"fileData,omitempty"` + FunctionCall *genai_sdk.FunctionCall `json:"functionCall,omitempty"` + FunctionResponse *genai_sdk.FunctionResponse `json:"functionResponse,omitempty"` + InlineData *CustomBlob `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` +} + +// ToGenAIPart converts CustomPart to genai_sdk.Part +func (p *CustomPart) ToGenAIPart() *genai_sdk.Part { + part := &genai_sdk.Part{ + VideoMetadata: p.VideoMetadata, + Thought: p.Thought, + CodeExecutionResult: p.CodeExecutionResult, + ExecutableCode: p.ExecutableCode, + FileData: p.FileData, + FunctionCall: p.FunctionCall, + FunctionResponse: p.FunctionResponse, + Text: p.Text, + } + + if p.InlineData != nil { + part.InlineData = &genai_sdk.Blob{ + Data: p.InlineData.Data, + MIMEType: p.InlineData.MIMEType, + } + } + + return part +} + +// CustomContent handles Google GenAI Content with custom Part unmarshalling +type CustomContent struct { + Parts []*CustomPart `json:"parts,omitempty"` + Role string `json:"role,omitempty"` +} + +// ToGenAIContent converts CustomContent to genai_sdk.Content +func (c *CustomContent) ToGenAIContent() genai_sdk.Content { + parts := make([]*genai_sdk.Part, len(c.Parts)) + for i, part := range c.Parts { + parts[i] = part.ToGenAIPart() + } + + return genai_sdk.Content{ + Parts: parts, + Role: c.Role, + } +} + +// ensureExtraParams ensures that bifrostReq.Params and bifrostReq.Params.ExtraParams are initialized +func ensureExtraParams(bifrostReq *schemas.BifrostRequest) { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + } + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } +} + +type GeminiChatRequest struct { + Model string `json:"model,omitempty"` // Model field for explicit model specification + Contents []CustomContent `json:"contents,omitempty"` // For chat completion requests + Requests []EmbeddingRequest `json:"requests,omitempty"` // For batch embedding requests + SystemInstruction *CustomContent `json:"systemInstruction,omitempty"` + GenerationConfig genai_sdk.GenerationConfig `json:"generationConfig,omitempty"` + SafetySettings []genai_sdk.SafetySetting `json:"safetySettings,omitempty"` + Tools []genai_sdk.Tool `json:"tools,omitempty"` + ToolConfig genai_sdk.ToolConfig `json:"toolConfig,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + CachedContent string `json:"cachedContent,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + Stream bool `json:"-"` // Internal field to track streaming requests + IsEmbedding bool `json:"-"` // Internal field to track if this is an embedding request + + // Embedding-specific parameters + TaskType *string `json:"taskType,omitempty"` + Title *string `json:"title,omitempty"` + OutputDimensionality *int `json:"outputDimensionality,omitempty"` +} + +// IsStreamingRequested implements the StreamingRequest interface +func (r *GeminiChatRequest) IsStreamingRequested() bool { + return r.Stream && !r.IsEmbedding +} + +// GeminiChatRequestError represents a Gemini chat completion error response +type GeminiChatRequestError struct { + Error GeminiChatRequestErrorStruct `json:"error"` // Error details following Google API format +} + +// GeminiChatRequestErrorStruct represents the error structure of a Gemini chat completion error response +type GeminiChatRequestErrorStruct struct { + Code int `json:"code"` // HTTP status code + Message string `json:"message"` // Error message + Status string `json:"status"` // Error status string (e.g., "INVALID_REQUEST") +} + +func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.Gemini, false) + + if provider == schemas.Vertex && !r.IsEmbedding { + // Add google/ prefix for Bifrost if not already present + if !strings.HasPrefix(model, "google/") { + model = "google/" + model + } + } + + // Handle embedding requests + if r.IsEmbedding { + // Extract texts from content (embedding requests) or contents (chat completion requests) + var texts []string + + // Check for batch embedding requests first + if len(r.Requests) > 0 { + for _, req := range r.Requests { + if req.Content != nil { + for _, part := range req.Content.Parts { + if part.Text != "" { + texts = append(texts, part.Text) + } + } + } + } + } + + // Fallback to contents (plural) for backward compatibility + if len(texts) == 0 { + for _, content := range r.Contents { + for _, part := range content.Parts { + if part.Text != "" { + texts = append(texts, part.Text) + } + } + } + } + + // Create embedding input + embeddingInput := &schemas.EmbeddingInput{ + Texts: texts, + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + EmbeddingInput: embeddingInput, + }, + } + + // Convert embedding parameters + params := r.convertEmbeddingParameters() + if params != nil { + bifrostReq.Params = params + } + + return bifrostReq + } + + // Handle chat completion requests + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + messages := []schemas.BifrostMessage{} + + allGenAiMessages := []genai_sdk.Content{} + if r.SystemInstruction != nil { + allGenAiMessages = append(allGenAiMessages, r.SystemInstruction.ToGenAIContent()) + } + for _, content := range r.Contents { + allGenAiMessages = append(allGenAiMessages, content.ToGenAIContent()) + } + + for _, content := range allGenAiMessages { + if len(content.Parts) == 0 { + continue + } + + // Handle multiple parts - collect all content and tool calls + var toolCalls []schemas.ToolCall + var contentBlocks []schemas.ContentBlock + var thoughtStr string // Track thought content for assistant/model + + for _, part := range content.Parts { + switch { + case part.Text != "": + // Handle thought content specially for assistant messages + if part.Thought && + (content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel)) { + thoughtStr = thoughtStr + part.Text + "\n" + } else { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &part.Text, + }) + } + + case part.FunctionCall != nil: + // Only add function calls for assistant messages + if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel) { + jsonArgs, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + jsonArgs = []byte(fmt.Sprintf("%v", part.FunctionCall.Args)) + } + id := part.FunctionCall.ID // create local copy + name := part.FunctionCall.Name // create local copy + toolCall := schemas.ToolCall{ + ID: bifrost.Ptr(id), + Type: fnTypePtr, + Function: schemas.FunctionCall{ + Name: &name, + Arguments: string(jsonArgs), + }, + } + toolCalls = append(toolCalls, toolCall) + } + + case part.FunctionResponse != nil: + // Create a separate tool response message + responseContent, err := json.Marshal(part.FunctionResponse.Response) + if err != nil { + responseContent = []byte(fmt.Sprintf("%v", part.FunctionResponse.Response)) + } + + toolResponseMsg := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(string(responseContent)), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: &part.FunctionResponse.Name, + }, + } + + messages = append(messages, toolResponseMsg) + + case part.InlineData != nil: + // Handle inline images/media - only append if it's actually an image + if isImageMimeType(part.InlineData.MIMEType) { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), + }, + }) + } + + case part.FileData != nil: + // Handle file data - only append if it's actually an image + if isImageMimeType(part.FileData.MIMEType) { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: part.FileData.FileURI, + }, + }) + } + + case part.ExecutableCode != nil: + // Handle executable code as text content + codeText := fmt.Sprintf("```%s\n%s\n```", part.ExecutableCode.Language, part.ExecutableCode.Code) + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &codeText, + }) + + case part.CodeExecutionResult != nil: + // Handle code execution results as text content + resultText := fmt.Sprintf("Code execution result (%s):\n%s", part.CodeExecutionResult.Outcome, part.CodeExecutionResult.Output) + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &resultText, + }) + } + } + + // Only create message if there's actual content, tool calls, or thought content + if len(contentBlocks) > 0 || len(toolCalls) > 0 || thoughtStr != "" { + // Create main message with content blocks + bifrostMsg := schemas.BifrostMessage{ + Role: func(r string) schemas.ModelChatMessageRole { + if r == string(genai_sdk.RoleModel) { // GenAI's internal alias + return schemas.ModelChatMessageRoleAssistant + } + return schemas.ModelChatMessageRole(r) + }(content.Role), + } + + // Set content only if there are content blocks + if len(contentBlocks) > 0 { + bifrostMsg.Content = schemas.MessageContent{ + ContentBlocks: &contentBlocks, + } + } + + // Set assistant-specific fields for assistant/model messages + if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel) { + if len(toolCalls) > 0 || thoughtStr != "" { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{} + if len(toolCalls) > 0 { + bifrostMsg.AssistantMessage.ToolCalls = &toolCalls + } + if thoughtStr != "" { + bifrostMsg.AssistantMessage.Thought = &thoughtStr + } + } + } + + messages = append(messages, bifrostMsg) + } + } + + bifrostReq.Input.ChatCompletionInput = &messages + + // Convert generation config to parameters + if params := r.convertGenerationConfigToParams(); params != nil { + bifrostReq.Params = params + } + + // Convert safety settings + if len(r.SafetySettings) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["safety_settings"] = r.SafetySettings + } + + // Convert additional request fields + if r.CachedContent != "" { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["cached_content"] = r.CachedContent + } + + // Convert response modalities + if len(r.ResponseModalities) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["response_modalities"] = r.ResponseModalities + } + + // Convert labels + if len(r.Labels) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["labels"] = r.Labels + } + + // Convert tools and tool config + if len(r.Tools) > 0 { + ensureExtraParams(bifrostReq) + + tools := make([]schemas.Tool, 0, len(r.Tools)) + for _, tool := range r.Tools { + if len(tool.FunctionDeclarations) > 0 { + for _, fn := range tool.FunctionDeclarations { + bifrostTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: fn.Name, + Description: fn.Description, + }, + } + // Convert parameters schema if present + if fn.Parameters != nil { + bifrostTool.Function.Parameters = r.convertSchemaToFunctionParameters(fn.Parameters) + } + tools = append(tools, bifrostTool) + } + } + // Handle other tool types (Retrieval, GoogleSearch, etc.) as ExtraParams + if tool.Retrieval != nil { + bifrostReq.Params.ExtraParams["retrieval"] = tool.Retrieval + } + if tool.GoogleSearch != nil { + bifrostReq.Params.ExtraParams["google_search"] = tool.GoogleSearch + } + if tool.CodeExecution != nil { + bifrostReq.Params.ExtraParams["code_execution"] = tool.CodeExecution + } + } + + if len(tools) > 0 { + bifrostReq.Params.Tools = &tools + } + } + + // Convert tool config + if r.ToolConfig.FunctionCallingConfig != nil || r.ToolConfig.RetrievalConfig != nil { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["tool_config"] = r.ToolConfig + } + + return bifrostReq +} + +// convertEmbeddingParameters converts Gemini embedding request parameters to ModelParameters +func (r *GeminiChatRequest) convertEmbeddingParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + // Check for parameters from batch embedding requests first + if len(r.Requests) > 0 { + // Use parameters from the first request in the batch + firstReq := r.Requests[0] + if firstReq.TaskType != nil { + params.ExtraParams["taskType"] = *firstReq.TaskType + } + if firstReq.Title != nil { + params.ExtraParams["title"] = *firstReq.Title + } + if firstReq.OutputDimensionality != nil { + params.Dimensions = firstReq.OutputDimensionality + } + } else { + // Fallback to top-level embedding parameters for single requests + if r.TaskType != nil { + params.ExtraParams["taskType"] = *r.TaskType + } + if r.Title != nil { + params.ExtraParams["title"] = *r.Title + } + if r.OutputDimensionality != nil { + params.Dimensions = r.OutputDimensionality + } + } + + return params +} + +// convertGenerationConfigToParams converts Gemini GenerationConfig to ModelParameters +func (r *GeminiChatRequest) convertGenerationConfigToParams() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + config := r.GenerationConfig + + // Map generation config fields to parameters + if config.Temperature != nil { + temp := float64(*config.Temperature) + params.Temperature = &temp + } + if config.TopP != nil { + params.TopP = bifrost.Ptr(float64(*config.TopP)) + } + if config.TopK != nil { + params.TopK = bifrost.Ptr(int(*config.TopK)) + } + if config.MaxOutputTokens > 0 { + maxTokens := int(config.MaxOutputTokens) + params.MaxTokens = &maxTokens + } + if config.CandidateCount > 0 { + params.ExtraParams["candidate_count"] = config.CandidateCount + } + if len(config.StopSequences) > 0 { + params.StopSequences = &config.StopSequences + } + if config.PresencePenalty != nil { + params.PresencePenalty = bifrost.Ptr(float64(*config.PresencePenalty)) + } + if config.FrequencyPenalty != nil { + params.FrequencyPenalty = bifrost.Ptr(float64(*config.FrequencyPenalty)) + } + if config.Seed != nil { + params.ExtraParams["seed"] = *config.Seed + } + if config.ResponseMIMEType != "" { + params.ExtraParams["response_mime_type"] = config.ResponseMIMEType + } + if config.ResponseLogprobs { + params.ExtraParams["response_logprobs"] = config.ResponseLogprobs + } + if config.Logprobs != nil { + params.ExtraParams["logprobs"] = *config.Logprobs + } + + return params +} + +// convertSchemaToFunctionParameters converts genai.Schema to schemas.FunctionParameters +func (r *GeminiChatRequest) convertSchemaToFunctionParameters(schema *genai_sdk.Schema) schemas.FunctionParameters { + params := schemas.FunctionParameters{ + Type: string(schema.Type), + } + + if schema.Description != "" { + params.Description = &schema.Description + } + + if len(schema.Required) > 0 { + params.Required = schema.Required + } + + if len(schema.Properties) > 0 { + params.Properties = make(map[string]interface{}) + for k, v := range schema.Properties { + params.Properties[k] = v + } + } + + if len(schema.Enum) > 0 { + params.Enum = &schema.Enum + } + + return params +} + +func DeriveGenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) interface{} { + if bifrostResp == nil { + return nil + } + + // Check if this is an embedding response by looking for embedding data + if len(bifrostResp.Data) > 0 { + // This is an embedding response + return DeriveGeminiEmbeddingFromBifrostResponse(bifrostResp) + } + + // This is a chat completion response + genaiResp := &genai_sdk.GenerateContentResponse{ + Candidates: make([]*genai_sdk.Candidate, len(bifrostResp.Choices)), + } + + if bifrostResp.Usage != nil { + genaiResp.UsageMetadata = &genai_sdk.GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(bifrostResp.Usage.PromptTokens), + CandidatesTokenCount: int32(bifrostResp.Usage.CompletionTokens), + TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), + } + } + + for i, choice := range bifrostResp.Choices { + candidate := &genai_sdk.Candidate{ + Index: int32(choice.Index), + } + if choice.FinishReason != nil { + candidate.FinishReason = genai_sdk.FinishReason(*choice.FinishReason) + } + + if bifrostResp.Usage != nil { + candidate.TokenCount = int32(bifrostResp.Usage.CompletionTokens) + } + + parts := []*genai_sdk.Part{} + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + parts = append(parts, &genai_sdk.Part{Text: *choice.Message.Content.ContentStr}) + } else if choice.Message.Content.ContentBlocks != nil { + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + parts = append(parts, &genai_sdk.Part{Text: *block.Text}) + } + } + } + + // Handle tool calls + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + argsMap := make(map[string]interface{}) + if toolCall.Function.Arguments != "" { + // Attempt to unmarshal arguments, but don't fail if it's not valid JSON, + // as BifrostResponse.FunctionCall.Arguments is a string. + // genai.FunctionCall.Args expects map[string]any. + json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) + } + if toolCall.Function.Name != nil { + fc := &genai_sdk.FunctionCall{ + Name: *toolCall.Function.Name, + Args: argsMap, + } + if toolCall.ID != nil { + fc.ID = *toolCall.ID + } + parts = append(parts, &genai_sdk.Part{FunctionCall: fc}) + } + } + } + + // Handle thinking content if present + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { + parts = append(parts, &genai_sdk.Part{ + Text: *choice.Message.AssistantMessage.Thought, + Thought: true, + }) + } + + if len(parts) > 0 { + candidate.Content = &genai_sdk.Content{ + Parts: parts, + Role: string(choice.Message.Role), + } + } + + // Handle safety ratings if available (from ExtraFields) + if bifrostResp.ExtraFields.RawResponse != nil { + if rawMap, ok := bifrostResp.ExtraFields.RawResponse.(map[string]interface{}); ok { + if candidates, ok := rawMap["candidates"].([]interface{}); ok && len(candidates) > i { + if candidateMap, ok := candidates[i].(map[string]interface{}); ok { + if safetyRatings, ok := candidateMap["safetyRatings"].([]interface{}); ok { + var ratings []*genai_sdk.SafetyRating + for _, rating := range safetyRatings { + if ratingMap, ok := rating.(map[string]interface{}); ok { + sr := &genai_sdk.SafetyRating{} + if category, ok := ratingMap["category"].(string); ok { + sr.Category = genai_sdk.HarmCategory(category) + } + if probability, ok := ratingMap["probability"].(string); ok { + sr.Probability = genai_sdk.HarmProbability(probability) + } + if blocked, ok := ratingMap["blocked"].(bool); ok { + sr.Blocked = blocked + } + ratings = append(ratings, sr) + } + } + candidate.SafetyRatings = ratings + } + } + } + } + } + + genaiResp.Candidates[i] = candidate + } + + return genaiResp +} + +// DeriveGeminiStreamFromBifrostResponse converts a Bifrost streaming response to Google GenAI streaming format +func DeriveGeminiStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *genai_sdk.GenerateContentResponse { + if bifrostResp == nil { + return nil + } + + genaiResp := &genai_sdk.GenerateContentResponse{ + Candidates: make([]*genai_sdk.Candidate, len(bifrostResp.Choices)), + } + + // Set usage metadata if available + if bifrostResp.Usage != nil { + genaiResp.UsageMetadata = &genai_sdk.GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(bifrostResp.Usage.PromptTokens), + CandidatesTokenCount: int32(bifrostResp.Usage.CompletionTokens), + TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), + } + } + + // Convert choices to streaming format + for i, choice := range bifrostResp.Choices { + candidate := &genai_sdk.Candidate{ + Index: int32(choice.Index), + } + + // Set finish reason if present + if choice.FinishReason != nil { + candidate.FinishReason = genai_sdk.FinishReason(*choice.FinishReason) + } + + // Set token count if available + if bifrostResp.Usage != nil { + candidate.TokenCount = int32(bifrostResp.Usage.CompletionTokens) + } + + // Handle streaming response delta + var parts []*genai_sdk.Part + + if choice.BifrostStreamResponseChoice != nil { + // Convert streaming delta to parts + delta := choice.BifrostStreamResponseChoice.Delta + + // Handle text content delta + if delta.Content != nil && *delta.Content != "" { + parts = append(parts, &genai_sdk.Part{ + Text: *delta.Content, + }) + } + + // Handle thinking content delta + if delta.Thought != nil && *delta.Thought != "" { + parts = append(parts, &genai_sdk.Part{ + Text: *delta.Thought, + Thought: true, + }) + } + + // Handle tool call deltas + if len(delta.ToolCalls) > 0 { + for _, toolCall := range delta.ToolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + // Convert tool call arguments from JSON string to map + argsMap := make(map[string]interface{}) + if toolCall.Function.Arguments != "" { + json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) + } + + fc := &genai_sdk.FunctionCall{ + Name: *toolCall.Function.Name, + Args: argsMap, + } + if toolCall.ID != nil { + fc.ID = *toolCall.ID + } + + parts = append(parts, &genai_sdk.Part{ + FunctionCall: fc, + }) + } + } + } + + } + + // Set content if we have parts + if len(parts) > 0 { + candidate.Content = &genai_sdk.Content{ + Parts: parts, + Role: string(schemas.ModelChatMessageRoleAssistant), // Streaming responses are typically from assistant + } + } + + genaiResp.Candidates[i] = candidate + } + + // Set response metadata + if bifrostResp.ID != "" { + genaiResp.ResponseID = bifrostResp.ID + } + if bifrostResp.Model != "" { + genaiResp.ModelVersion = bifrostResp.Model + } + + return genaiResp +} + +// GeminiEmbeddingResponse represents a Google GenAI embedding response +type GeminiEmbeddingResponse struct { + Embeddings []GeminiEmbedding `json:"embeddings"` + Metadata *EmbedContentMetadata `json:"metadata,omitempty"` +} + +// GeminiEmbedding represents a single embedding in the response +type GeminiEmbedding struct { + Values []float32 `json:"values"` + Statistics *ContentEmbeddingStatistics `json:"statistics,omitempty"` +} + +// EmbedContentMetadata represents request-level metadata for Vertex API +type EmbedContentMetadata struct { + BillableCharacterCount int32 `json:"billableCharacterCount,omitempty"` +} + +// ContentEmbeddingStatistics represents statistics of the input text +type ContentEmbeddingStatistics struct { + TokenCount int32 `json:"tokenCount,omitempty"` +} + +// DeriveGeminiEmbeddingFromBifrostResponse converts a Bifrost embedding response to Google GenAI format +func DeriveGeminiEmbeddingFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *GeminiEmbeddingResponse { + if bifrostResp == nil || len(bifrostResp.Data) == 0 { + return nil + } + + genaiResp := &GeminiEmbeddingResponse{ + Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)), + } + + // Convert embeddings + for i, embedding := range bifrostResp.Data { + var values []float32 + if embedding.Embedding.EmbeddingArray != nil { + values = *embedding.Embedding.EmbeddingArray + } + + geminiEmbedding := GeminiEmbedding{ + Values: values, + } + + // Check for Vertex-specific statistics in response extra fields + if bifrostResp.ExtraFields.RawResponse != nil { + if rawMap, ok := bifrostResp.ExtraFields.RawResponse.(map[string]interface{}); ok { + // Check if this is an array of embeddings with individual statistics + if embeddings, ok := rawMap["embeddings"].([]interface{}); ok && len(embeddings) > i { + if embeddingMap, ok := embeddings[i].(map[string]interface{}); ok { + if statistics, ok := embeddingMap["statistics"].(map[string]interface{}); ok { + if tokenCount, ok := statistics["tokenCount"].(float64); ok { + geminiEmbedding.Statistics = &ContentEmbeddingStatistics{ + TokenCount: int32(tokenCount), + } + } + } + } + } + } + } + + genaiResp.Embeddings[i] = geminiEmbedding + } + + // Check for Vertex-specific metadata in response extra fields + if bifrostResp.ExtraFields.RawResponse != nil { + if rawMap, ok := bifrostResp.ExtraFields.RawResponse.(map[string]interface{}); ok { + if metadata, ok := rawMap["metadata"].(map[string]interface{}); ok { + if billableCharCount, ok := metadata["billableCharacterCount"].(float64); ok { + genaiResp.Metadata = &EmbedContentMetadata{ + BillableCharacterCount: int32(billableCharCount), + } + } + } + } + } + + return genaiResp +} + +// DeriveGeminiErrorFromBifrostError derives a GeminiChatRequestError from a BifrostError +func DeriveGeminiErrorFromBifrostError(bifrostErr *schemas.BifrostError) *GeminiChatRequestError { + if bifrostErr == nil { + return nil + } + + code := 500 + status := "" + + if bifrostErr.Error.Type != nil { + status = *bifrostErr.Error.Type + } + + if bifrostErr.StatusCode != nil { + code = *bifrostErr.StatusCode + } + + return &GeminiChatRequestError{ + Error: GeminiChatRequestErrorStruct{ + Code: code, + Message: bifrostErr.Error.Message, + Status: status, + }, + } +} + +// DeriveGeminiStreamFromBifrostError derives a Gemini streaming error from a BifrostError +func DeriveGeminiStreamFromBifrostError(bifrostErr *schemas.BifrostError) *GeminiChatRequestError { + // For streaming, we use the same error format as regular Gemini errors + return DeriveGeminiErrorFromBifrostError(bifrostErr) +} + +// isImageMimeType checks if a MIME type represents an image format +func isImageMimeType(mimeType string) bool { + if mimeType == "" { + return false + } + + // Convert to lowercase for case-insensitive comparison + mimeType = strings.ToLower(mimeType) + + // Remove any parameters (e.g., "image/jpeg; charset=utf-8" -> "image/jpeg") + if idx := strings.Index(mimeType, ";"); idx != -1 { + mimeType = strings.TrimSpace(mimeType[:idx]) + } + + // If it starts with "image/", it's an image + if strings.HasPrefix(mimeType, "image/") { + return true + } + + // Check for common image formats that might not have the "image/" prefix + commonImageTypes := []string{ + "jpeg", + "jpg", + "png", + "gif", + "webp", + "bmp", + "svg", + "tiff", + "ico", + "avif", + } + + // Check if the mimeType contains any of the common image type strings + for _, imageType := range commonImageTypes { + if strings.Contains(mimeType, imageType) { + return true + } + } + + return false +} diff --git a/transports/bifrost-http/integrations/langchain/router.go b/transports/bifrost-http/integrations/langchain/router.go new file mode 100644 index 000000000..c62cebe8b --- /dev/null +++ b/transports/bifrost-http/integrations/langchain/router.go @@ -0,0 +1,36 @@ +package langchain + +import ( + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +// LangChainRouter holds route registrations for LangChain endpoints. +// It supports standard chat completions and image-enabled vision capabilities. +// LangChain is fully OpenAI-compatible, so we reuse OpenAI types +// with aliases for clarity and minimal LangChain-specific extensions +type LangChainRouter struct { + *integrations.GenericRouter +} + +// NewLangChainRouter creates a new LangChainRouter with the given bifrost client. +func NewLangChainRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *LangChainRouter { + routes := []integrations.RouteConfig{} + + // Add OpenAI routes to LangChain for OpenAI API compatibility + routes = append(routes, openai.CreateOpenAIRouteConfigs("/langchain", handlerStore)...) + + // Add Anthropic routes to LangChain for Anthropic API compatibility + routes = append(routes, anthropic.CreateAnthropicRouteConfigs("/langchain")...) + + // Add GenAI routes to LangChain for Vertex AI compatibility + routes = append(routes, genai.CreateGenAIRouteConfigs("/langchain")...) + + return &LangChainRouter{ + GenericRouter: integrations.NewGenericRouter(client, handlerStore, routes), + } +} diff --git a/transports/bifrost-http/integrations/litellm/router.go b/transports/bifrost-http/integrations/litellm/router.go new file mode 100644 index 000000000..b588a7ec8 --- /dev/null +++ b/transports/bifrost-http/integrations/litellm/router.go @@ -0,0 +1,36 @@ +package litellm + +import ( + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +// LiteLLMRouter holds route registrations for LiteLLM endpoints. +// It supports standard chat completions and image-enabled vision capabilities. +// LiteLLM is fully OpenAI-compatible, so we reuse OpenAI types +// with aliases for clarity and minimal LiteLLM-specific extensions +type LiteLLMRouter struct { + *integrations.GenericRouter +} + +// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client. +func NewLiteLLMRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *LiteLLMRouter { + routes := []integrations.RouteConfig{} + + // Add OpenAI routes to LiteLLM for OpenAI API compatibility + routes = append(routes, openai.CreateOpenAIRouteConfigs("/litellm", handlerStore)...) + + // Add Anthropic routes to LiteLLM for Anthropic API compatibility + routes = append(routes, anthropic.CreateAnthropicRouteConfigs("/litellm")...) + + // Add GenAI routes to LiteLLM for Vertex AI compatibility + routes = append(routes, genai.CreateGenAIRouteConfigs("/litellm")...) + + return &LiteLLMRouter{ + GenericRouter: integrations.NewGenericRouter(client, handlerStore, routes), + } +} diff --git a/transports/bifrost-http/integrations/openai/router.go b/transports/bifrost-http/integrations/openai/router.go new file mode 100644 index 000000000..03f470bf2 --- /dev/null +++ b/transports/bifrost-http/integrations/openai/router.go @@ -0,0 +1,346 @@ +package openai + +import ( + "errors" + "strconv" + "strings" + + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// setAzureModelName sets the model name for Azure requests with proper prefix handling +// When deploymentID is present, it always takes precedence over the request body model +// to avoid deployment/model mismatches. +func setAzureModelName(currentModel, deploymentID string) string { + if deploymentID != "" { + return "azure/" + deploymentID + } else if currentModel != "" && !strings.HasPrefix(currentModel, "azure/") { + return "azure/" + currentModel + } + return currentModel +} + +// OpenAIRouter holds route registrations for OpenAI endpoints. +// It supports standard chat completions, speech synthesis, audio transcription, and streaming capabilities with OpenAI-specific formatting. +type OpenAIRouter struct { + *integrations.GenericRouter +} + +func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, req interface{}) error { + azureKey := ctx.Request.Header.Peek("authorization") + deploymentEndpoint := ctx.Request.Header.Peek("x-bf-azure-endpoint") + deploymentID := ctx.UserValue("deployment-id") + apiVersion := ctx.QueryArgs().Peek("api-version") + + if deploymentID != nil { + deploymentIDStr, ok := deploymentID.(string) + if !ok { + return errors.New("deployment-id is required in path") + } + + switch r := req.(type) { + case *OpenAIChatRequest: + r.Model = setAzureModelName(r.Model, deploymentIDStr) + case *OpenAISpeechRequest: + r.Model = setAzureModelName(r.Model, deploymentIDStr) + case *OpenAITranscriptionRequest: + r.Model = setAzureModelName(r.Model, deploymentIDStr) + case *OpenAIEmbeddingRequest: + r.Model = setAzureModelName(r.Model, deploymentIDStr) + } + + if deploymentEndpoint == nil || azureKey == nil || !handlerStore.ShouldAllowDirectKeys() { + return nil + } + + azureKeyStr := string(azureKey) + deploymentEndpointStr := string(deploymentEndpoint) + apiVersionStr := string(apiVersion) + + key := schemas.Key{ + ID: uuid.New().String(), + Models: []string{}, + AzureKeyConfig: &schemas.AzureKeyConfig{}, + } + + if deploymentEndpointStr != "" && deploymentIDStr != "" && azureKeyStr != "" { + key.Value = strings.TrimPrefix(azureKeyStr, "Bearer ") + key.AzureKeyConfig.Endpoint = deploymentEndpointStr + key.AzureKeyConfig.Deployments = map[string]string{deploymentIDStr: deploymentIDStr} + } + + if apiVersionStr != "" { + key.AzureKeyConfig.APIVersion = &apiVersionStr + } + + ctx.SetUserValue(string(schemas.BifrostContextKeyDirectKey), key) + + return nil + } + + return nil + } +} + +// CreateOpenAIRouteConfigs creates route configurations for OpenAI endpoints. +func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []integrations.RouteConfig { + var routes []integrations.RouteConfig + + // Chat completions endpoint + for _, path := range []string{ + "/v1/chat/completions", + "/chat/completions", + "/openai/deployments/{deployment-id}/chat/completions", + } { + routes = append(routes, integrations.RouteConfig{ + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &OpenAIChatRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if openaiReq, ok := req.(*OpenAIChatRequest); ok { + return openaiReq.ConvertToBifrostRequest(pathPrefix != "/openai"), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAIFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + StreamConfig: &integrations.StreamConfig{ + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAIStreamFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIStreamFromBifrostError(err) + }, + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + + // Embeddings endpoint + for _, path := range []string{ + "/v1/embeddings", + "/embeddings", + "/openai/deployments/{deployment-id}/embeddings", + } { + routes = append(routes, integrations.RouteConfig{ + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &OpenAIEmbeddingRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if embeddingReq, ok := req.(*OpenAIEmbeddingRequest); ok { + return embeddingReq.ConvertToBifrostRequest(pathPrefix != "/openai"), nil + } + return nil, errors.New("invalid embedding request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAIEmbeddingFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + + // Speech synthesis endpoint + for _, path := range []string{ + "/v1/audio/speech", + "/audio/speech", + "/openai/deployments/{deployment-id}/audio/speech", + } { + routes = append(routes, integrations.RouteConfig{ + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &OpenAISpeechRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if speechReq, ok := req.(*OpenAISpeechRequest); ok { + return speechReq.ConvertToBifrostRequest(pathPrefix != "/openai"), nil + } + return nil, errors.New("invalid speech request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + speechResp := DeriveOpenAISpeechFromBifrostResponse(resp) + if speechResp == nil { + return nil, errors.New("failed to convert speech response") + } + // For speech, we return the raw audio data directly + return speechResp.Audio, nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + StreamConfig: &integrations.StreamConfig{ + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAISpeechFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + + // Audio transcription endpoint + for _, path := range []string{ + "/v1/audio/transcriptions", + "/audio/transcriptions", + "/openai/deployments/{deployment-id}/audio/transcriptions", + } { + routes = append(routes, integrations.RouteConfig{ + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &OpenAITranscriptionRequest{} + }, + RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if transcriptionReq, ok := req.(*OpenAITranscriptionRequest); ok { + return transcriptionReq.ConvertToBifrostRequest(pathPrefix != "/openai"), nil + } + return nil, errors.New("invalid transcription request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAITranscriptionFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + StreamConfig: &integrations.StreamConfig{ + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAITranscriptionFromBifrostResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + + return routes +} + +// NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. +func NewOpenAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore) *OpenAIRouter { + return &OpenAIRouter{ + GenericRouter: integrations.NewGenericRouter(client, handlerStore, CreateOpenAIRouteConfigs("/openai", handlerStore)), + } +} + +// parseTranscriptionMultipartRequest is a RequestParser that handles multipart/form-data for transcription requests +func parseTranscriptionMultipartRequest(ctx *fasthttp.RequestCtx, req interface{}) error { + transcriptionReq, ok := req.(*OpenAITranscriptionRequest) + if !ok { + return errors.New("invalid request type for transcription") + } + + // Parse multipart form + form, err := ctx.MultipartForm() + if err != nil { + return err + } + + // Extract model (required) + modelValues := form.Value["model"] + if len(modelValues) == 0 || modelValues[0] == "" { + return errors.New("model field is required") + } + transcriptionReq.Model = modelValues[0] + + // Extract file (required) + fileHeaders := form.File["file"] + if len(fileHeaders) == 0 { + return errors.New("file field is required") + } + + fileHeader := fileHeaders[0] + file, err := fileHeader.Open() + if err != nil { + return err + } + defer file.Close() + + // Read file data + fileData := make([]byte, fileHeader.Size) + if _, err := file.Read(fileData); err != nil { + return err + } + transcriptionReq.File = fileData + + // Extract optional parameters + if languageValues := form.Value["language"]; len(languageValues) > 0 && languageValues[0] != "" { + language := languageValues[0] + transcriptionReq.Language = &language + } + + if promptValues := form.Value["prompt"]; len(promptValues) > 0 && promptValues[0] != "" { + prompt := promptValues[0] + transcriptionReq.Prompt = &prompt + } + + if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" { + responseFormat := responseFormatValues[0] + transcriptionReq.ResponseFormat = &responseFormat + } + + if temperatureValues := form.Value["temperature"]; len(temperatureValues) > 0 && temperatureValues[0] != "" { + temp, err := strconv.ParseFloat(temperatureValues[0], 64) + if err != nil { + return errors.New("invalid temperature value") + } + transcriptionReq.Temperature = &temp + } + + // Handle include[] array format used by OpenAI + if includeValues := form.Value["include[]"]; len(includeValues) > 0 { + transcriptionReq.Include = includeValues + } else if includeValues := form.Value["include"]; len(includeValues) > 0 && includeValues[0] != "" { + // Fallback: Handle comma-separated values for backwards compatibility + includes := strings.Split(includeValues[0], ",") + // Trim whitespace from each value + for i, v := range includes { + includes[i] = strings.TrimSpace(v) + } + transcriptionReq.Include = includes + } + + // Handle timestamp_granularities[] array format used by OpenAI + if timestampValues := form.Value["timestamp_granularities[]"]; len(timestampValues) > 0 { + transcriptionReq.TimestampGranularities = timestampValues + } else if timestampValues := form.Value["timestamp_granularities"]; len(timestampValues) > 0 && timestampValues[0] != "" { + // Fallback: Handle comma-separated values for backwards compatibility + granularities := strings.Split(timestampValues[0], ",") + // Trim whitespace from each value + for i, v := range granularities { + granularities[i] = strings.TrimSpace(v) + } + transcriptionReq.TimestampGranularities = granularities + } + + if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] != "" { + stream, err := strconv.ParseBool(streamValues[0]) + if err != nil { + return errors.New("invalid stream value") + } + transcriptionReq.Stream = &stream + } + + return nil +} diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go new file mode 100644 index 000000000..571ad5fb4 --- /dev/null +++ b/transports/bifrost-http/integrations/openai/types.go @@ -0,0 +1,616 @@ +package openai + +import ( + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" +) + +// OpenAIChatRequest represents an OpenAI chat completion request +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []schemas.BifrostMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stop interface{} `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Tools *[]schemas.Tool `json:"tools,omitempty"` // Reuse schema type + ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + ReasoningEffort *string `json:"reasoning_effort,omitempty"` + StreamOptions *map[string]interface{} `json:"stream_options,omitempty"` +} + +// OpenAISpeechRequest represents an OpenAI speech synthesis request +type OpenAISpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + ResponseFormat *string `json:"response_format,omitempty"` + Speed *float64 `json:"speed,omitempty"` + Instructions *string `json:"instructions,omitempty"` + StreamFormat *string `json:"stream_format,omitempty"` +} + +// OpenAITranscriptionRequest represents an OpenAI transcription request +// Note: This is used for JSON body parsing, actual form parsing is handled in the router +type OpenAITranscriptionRequest struct { + Model string `json:"model"` + File []byte `json:"file"` // Binary audio data + Language *string `json:"language,omitempty"` + Prompt *string `json:"prompt,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Include []string `json:"include,omitempty"` + TimestampGranularities []string `json:"timestamp_granularities,omitempty"` + Stream *bool `json:"stream,omitempty"` +} + +// OpenAIEmbeddingRequest represents an OpenAI embedding request +type OpenAIEmbeddingRequest struct { + Model string `json:"model"` + Input any `json:"input"` // Can be string, []string, []int, [][]int + EncodingFormat *string `json:"encoding_format,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` + User *string `json:"user,omitempty"` +} + +// IsStreamingRequested implements the StreamingRequest interface +func (r *OpenAIChatRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream +} + +// IsStreamingRequested implements the StreamingRequest interface for speech +func (r *OpenAISpeechRequest) IsStreamingRequested() bool { + return r.StreamFormat != nil && *r.StreamFormat == "sse" +} + +// IsStreamingRequested implements the StreamingRequest interface for transcription +func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream +} + +// IsStreamingRequested implements the StreamingRequest interface for embeddings +// Note: Embeddings don't support streaming in OpenAI API +func (r *OpenAIEmbeddingRequest) IsStreamingRequested() bool { + return false +} + +// OpenAIChatResponse represents an OpenAI chat completion response +type OpenAIChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` // Reuse schema type + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` +} + +// OpenAIEmbeddingResponse represents an OpenAI embedding response +type OpenAIEmbeddingResponse struct { + Object string `json:"object"` + Data []schemas.BifrostEmbedding `json:"data"` + Model string `json:"model"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` +} + +// OpenAIChatError represents an OpenAI chat completion error response +type OpenAIChatError struct { + EventID string `json:"event_id"` // Unique identifier for the error event + Type string `json:"type"` // Type of error + Error struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking + } `json:"error"` +} + +// OpenAIChatErrorStruct represents the error structure of an OpenAI chat completion error response +type OpenAIChatErrorStruct struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking +} + +// OpenAIStreamChoice represents a choice in a streaming response chunk +type OpenAIStreamChoice struct { + Index int `json:"index"` + Delta *OpenAIStreamDelta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + LogProbs *schemas.LogProbs `json:"logprobs,omitempty"` +} + +// OpenAIStreamDelta represents the incremental content in a streaming chunk +type OpenAIStreamDelta struct { + Role *string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + ToolCalls *[]schemas.ToolCall `json:"tool_calls,omitempty"` +} + +// OpenAIStreamResponse represents a single chunk in the OpenAI streaming response +type OpenAIStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` + Choices []OpenAIStreamChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` +} + +// ConvertToBifrostRequest converts an OpenAI chat request to Bifrost format +func (r *OpenAIChatRequest) ConvertToBifrostRequest(checkProviderFromModel bool) *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI, checkProviderFromModel) + + // Convert parameters first + params := r.convertParameters() + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + ChatCompletionInput: &r.Messages, + }, + Params: filterParams(provider, params), + } + + return bifrostReq +} + +// ConvertToBifrostRequest converts an OpenAI speech request to Bifrost format +func (r *OpenAISpeechRequest) ConvertToBifrostRequest(checkProviderFromModel bool) *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI, checkProviderFromModel) + + // Create speech input + speechInput := &schemas.SpeechInput{ + Input: r.Input, + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &r.Voice, + }, + } + + // Set response format if provided + if r.ResponseFormat != nil { + speechInput.ResponseFormat = *r.ResponseFormat + } + + // Set instructions if provided + if r.Instructions != nil { + speechInput.Instructions = *r.Instructions + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + SpeechInput: speechInput, + }, + } + + // Convert parameters first + params := r.convertSpeechParameters() + + // Map parameters + bifrostReq.Params = filterParams(provider, params) + + return bifrostReq +} + +// ConvertToBifrostRequest converts an OpenAI transcription request to Bifrost format +func (r *OpenAITranscriptionRequest) ConvertToBifrostRequest(checkProviderFromModel bool) *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI, checkProviderFromModel) + + // Create transcription input + transcriptionInput := &schemas.TranscriptionInput{ + File: r.File, + } + + // Set optional fields + if r.Language != nil { + transcriptionInput.Language = r.Language + } + if r.Prompt != nil { + transcriptionInput.Prompt = r.Prompt + } + if r.ResponseFormat != nil { + transcriptionInput.ResponseFormat = r.ResponseFormat + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + TranscriptionInput: transcriptionInput, + }, + } + + // Convert parameters first + params := r.convertTranscriptionParameters() + + // Map parameters + bifrostReq.Params = filterParams(provider, params) + + return bifrostReq +} + +// ConvertToBifrostRequest converts an OpenAI embedding request to Bifrost format +func (r *OpenAIEmbeddingRequest) ConvertToBifrostRequest(checkProviderFromModel bool) *schemas.BifrostRequest { + provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI, checkProviderFromModel) + + // Create embedding input + embeddingInput := &schemas.EmbeddingInput{} + + // Cleaner coercion: marshal input and try to unmarshal into supported shapes + if raw, err := sonic.Marshal(r.Input); err == nil { + // 1) string + var s string + if err := sonic.Unmarshal(raw, &s); err == nil { + embeddingInput.Text = &s + } else { + // 2) []string + var ss []string + if err := sonic.Unmarshal(raw, &ss); err == nil { + embeddingInput.Texts = ss + } else { + // 3) []int + var i []int + if err := sonic.Unmarshal(raw, &i); err == nil { + embeddingInput.Embedding = i + } else { + // 4) [][]int + var ii [][]int + if err := sonic.Unmarshal(raw, &ii); err == nil { + embeddingInput.Embeddings = ii + } + } + } + } + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + EmbeddingInput: embeddingInput, + }, + } + + // Convert parameters first + params := r.convertEmbeddingParameters() + + // Map parameters + bifrostReq.Params = filterParams(provider, params) + + return bifrostReq +} + +// convertParameters converts OpenAI request parameters to Bifrost ModelParameters +// using direct field access for better performance and type safety. +func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + params.Tools = r.Tools + params.ToolChoice = r.ToolChoice + + // Direct field mapping + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.PresencePenalty != nil { + params.PresencePenalty = r.PresencePenalty + } + if r.FrequencyPenalty != nil { + params.FrequencyPenalty = r.FrequencyPenalty + } + if r.N != nil { + params.ExtraParams["n"] = *r.N + } + if r.LogProbs != nil { + params.ExtraParams["logprobs"] = *r.LogProbs + } + if r.TopLogProbs != nil { + params.ExtraParams["top_logprobs"] = *r.TopLogProbs + } + if r.Stop != nil { + params.ExtraParams["stop"] = r.Stop + } + if r.LogitBias != nil { + params.ExtraParams["logit_bias"] = r.LogitBias + } + if r.User != nil { + params.ExtraParams["user"] = *r.User + } + if r.Stream != nil { + params.ExtraParams["stream"] = *r.Stream + } + if r.Seed != nil { + params.ExtraParams["seed"] = *r.Seed + } + if r.StreamOptions != nil { + params.ExtraParams["stream_options"] = r.StreamOptions + } + if r.ResponseFormat != nil { + params.ExtraParams["response_format"] = r.ResponseFormat + } + if r.MaxCompletionTokens != nil { + params.ExtraParams["max_completion_tokens"] = *r.MaxCompletionTokens + } + if r.ReasoningEffort != nil { + params.ExtraParams["reasoning_effort"] = *r.ReasoningEffort + } + + return params +} + +// convertSpeechParameters converts OpenAI speech request parameters to Bifrost ModelParameters +func (r *OpenAISpeechRequest) convertSpeechParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + // Add speech-specific parameters + if r.Speed != nil { + params.ExtraParams["speed"] = *r.Speed + } + + return params +} + +// convertTranscriptionParameters converts OpenAI transcription request parameters to Bifrost ModelParameters +func (r *OpenAITranscriptionRequest) convertTranscriptionParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + // Add transcription-specific parameters + if r.Temperature != nil { + params.ExtraParams["temperature"] = *r.Temperature + } + if len(r.TimestampGranularities) > 0 { + params.ExtraParams["timestamp_granularities"] = r.TimestampGranularities + } + if len(r.Include) > 0 { + params.ExtraParams["include"] = r.Include + } + + return params +} + +// convertEmbeddingParameters converts OpenAI embedding request parameters to Bifrost ModelParameters +func (r *OpenAIEmbeddingRequest) convertEmbeddingParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + // Add embedding-specific parameters + if r.EncodingFormat != nil { + params.EncodingFormat = r.EncodingFormat + } + if r.Dimensions != nil { + params.Dimensions = r.Dimensions + } + if r.User != nil { + params.User = r.User + } + + return params +} + +// DeriveOpenAIFromBifrostResponse converts a Bifrost response to OpenAI format +func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatResponse { + if bifrostResp == nil { + return nil + } + + openaiResp := &OpenAIChatResponse{ + ID: bifrostResp.ID, + Object: bifrostResp.Object, + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: bifrostResp.Choices, + Usage: bifrostResp.Usage, + ServiceTier: bifrostResp.ServiceTier, + SystemFingerprint: bifrostResp.SystemFingerprint, + } + + return openaiResp +} + +// DeriveOpenAISpeechFromBifrostResponse converts a Bifrost speech response to OpenAI format +func DeriveOpenAISpeechFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *schemas.BifrostSpeech { + if bifrostResp == nil || bifrostResp.Speech == nil { + return nil + } + + return bifrostResp.Speech +} + +// DeriveOpenAITranscriptionFromBifrostResponse converts a Bifrost transcription response to OpenAI format +func DeriveOpenAITranscriptionFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *schemas.BifrostTranscribe { + if bifrostResp == nil || bifrostResp.Transcribe == nil { + return nil + } + return bifrostResp.Transcribe +} + +// DeriveOpenAIEmbeddingFromBifrostResponse converts a Bifrost embedding response to OpenAI format +func DeriveOpenAIEmbeddingFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIEmbeddingResponse { + if bifrostResp == nil || bifrostResp.Data == nil { + return nil + } + + return &OpenAIEmbeddingResponse{ + Object: "list", + Data: bifrostResp.Data, + Model: bifrostResp.Model, + Usage: bifrostResp.Usage, + ServiceTier: bifrostResp.ServiceTier, + SystemFingerprint: bifrostResp.SystemFingerprint, + } +} + +// DeriveOpenAIErrorFromBifrostError derives a OpenAIChatError from a BifrostError +func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAIChatError { + if bifrostErr == nil { + return nil + } + + // Provide blank strings for nil pointer fields + eventID := "" + if bifrostErr.EventID != nil { + eventID = *bifrostErr.EventID + } + + errorType := "" + if bifrostErr.Type != nil { + errorType = *bifrostErr.Type + } + + // Handle nested error fields with nil checks + errorStruct := OpenAIChatErrorStruct{ + Type: "", + Code: "", + Message: bifrostErr.Error.Message, + Param: bifrostErr.Error.Param, + EventID: eventID, + } + + if bifrostErr.Error.Type != nil { + errorStruct.Type = *bifrostErr.Error.Type + } + + if bifrostErr.Error.Code != nil { + errorStruct.Code = *bifrostErr.Error.Code + } + + if bifrostErr.Error.EventID != nil { + errorStruct.EventID = *bifrostErr.Error.EventID + } + + return &OpenAIChatError{ + EventID: eventID, + Type: errorType, + Error: errorStruct, + } +} + +// DeriveOpenAIStreamFromBifrostError derives an OpenAI streaming error from a BifrostError +func DeriveOpenAIStreamFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAIChatError { + // For streaming, we use the same error format as regular OpenAI errors + return DeriveOpenAIErrorFromBifrostError(bifrostErr) +} + +// DeriveOpenAIStreamFromBifrostResponse converts a Bifrost response to OpenAI streaming format +func DeriveOpenAIStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIStreamResponse { + if bifrostResp == nil { + return nil + } + + streamResp := &OpenAIStreamResponse{ + ID: bifrostResp.ID, + Object: "chat.completion.chunk", + Created: bifrostResp.Created, + Model: bifrostResp.Model, + SystemFingerprint: bifrostResp.SystemFingerprint, + Usage: bifrostResp.Usage, + } + + // Convert choices to streaming format + for _, choice := range bifrostResp.Choices { + streamChoice := OpenAIStreamChoice{ + Index: choice.Index, + FinishReason: choice.FinishReason, + } + + var delta *OpenAIStreamDelta + + // Handle streaming vs non-streaming choices + if choice.BifrostStreamResponseChoice != nil { + // This is a streaming response - use the delta directly + delta = &OpenAIStreamDelta{} + + // Only set fields that are not nil + if choice.BifrostStreamResponseChoice.Delta.Role != nil { + delta.Role = choice.BifrostStreamResponseChoice.Delta.Role + } + if choice.BifrostStreamResponseChoice.Delta.Content != nil { + delta.Content = choice.BifrostStreamResponseChoice.Delta.Content + } + if len(choice.BifrostStreamResponseChoice.Delta.ToolCalls) > 0 { + delta.ToolCalls = &choice.BifrostStreamResponseChoice.Delta.ToolCalls + } + } else if choice.BifrostNonStreamResponseChoice != nil { + // This is a non-streaming response - convert message to delta format + delta = &OpenAIStreamDelta{} + + // Convert role + role := string(choice.BifrostNonStreamResponseChoice.Message.Role) + delta.Role = &role + + // Convert content + if choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr != nil { + delta.Content = choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr + } + + // Convert tool calls if present (from AssistantMessage) + if choice.BifrostNonStreamResponseChoice.Message.AssistantMessage != nil && + choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls != nil { + delta.ToolCalls = choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls + } + + // Set LogProbs from non-streaming choice + if choice.BifrostNonStreamResponseChoice.LogProbs != nil { + streamChoice.LogProbs = choice.BifrostNonStreamResponseChoice.LogProbs + } + } + + // Ensure we have a valid delta with at least one field set + // If all fields are nil, we should skip this chunk or set an empty content + if delta != nil { + hasValidField := (delta.Role != nil) || (delta.Content != nil) || (delta.ToolCalls != nil) + if !hasValidField { + // Set empty content to ensure we have at least one field + emptyContent := "" + delta.Content = &emptyContent + } + streamChoice.Delta = delta + } + + streamResp.Choices = append(streamResp.Choices, streamChoice) + } + + return streamResp +} + +func filterParams(provider schemas.ModelProvider, p *schemas.ModelParameters) *schemas.ModelParameters { + if p == nil { + return nil + } + return integrations.ValidateAndFilterParamsForProvider(provider, p) +} diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go new file mode 100644 index 000000000..322bd8942 --- /dev/null +++ b/transports/bifrost-http/integrations/utils.go @@ -0,0 +1,1205 @@ +// Package integrations provides a generic router framework for handling different LLM provider APIs. +// +// CENTRALIZED STREAMING ARCHITECTURE: +// +// This package implements a centralized streaming approach where all stream handling logic +// is consolidated in the GenericRouter, eliminating the need for provider-specific StreamHandler +// implementations. The key components are: +// +// 1. StreamConfig: Defines streaming configuration for each route, including: +// - ResponseConverter: Converts BifrostResponse to provider-specific streaming format +// - ErrorConverter: Converts BifrostError to provider-specific streaming error format +// +// 2. Centralized Stream Processing: The GenericRouter handles all streaming logic: +// - SSE header management +// - Stream channel processing +// - Error handling and conversion +// - Response formatting and flushing +// - Stream closure (handled automatically by provider implementation) +// +// 3. Provider-Specific Type Conversion: Integration types.go files only handle type conversion: +// - Derive{Provider}StreamFromBifrostResponse: Convert responses to streaming format +// - Derive{Provider}StreamFromBifrostError: Convert errors to streaming error format +// +// BENEFITS: +// - Eliminates code duplication across provider-specific stream handlers +// - Centralizes streaming logic for consistency and maintainability +// - Separates concerns: routing logic vs type conversion +// - Automatic stream closure management by provider implementations +// - Consistent error handling across all providers +// +// USAGE EXAMPLE: +// +// routes := []RouteConfig{ +// { +// Path: "/openai/chat/completions", +// Method: "POST", +// // ... other configs ... +// StreamConfig: &StreamConfig{ +// ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { +// return DeriveOpenAIStreamFromBifrostResponse(resp), nil +// }, +// ErrorConverter: func(err *schemas.BifrostError) interface{} { +// return DeriveOpenAIStreamFromBifrostError(err) +// }, +// }, +// }, +// } +package integrations + +import ( + "context" + "encoding/json" + "fmt" + "log" + "regexp" + "strconv" + "strings" + + "bufio" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ExtensionRouter defines the interface that all integration routers must implement +// to register their routes with the main HTTP router. +type ExtensionRouter interface { + RegisterRoutes(r *router.Router) +} + +// StreamingRequest interface for requests that support streaming +type StreamingRequest interface { + IsStreamingRequested() bool +} + +// RequestConverter is a function that converts integration-specific requests to Bifrost format. +// It takes the parsed request object and returns a BifrostRequest ready for processing. +type RequestConverter func(req interface{}) (*schemas.BifrostRequest, error) + +// ResponseConverter is a function that converts Bifrost responses to integration-specific format. +// It takes a BifrostResponse and returns the format expected by the specific integration. +type ResponseConverter func(*schemas.BifrostResponse) (interface{}, error) + +// StreamResponseConverter is a function that converts Bifrost responses to integration-specific streaming format. +// It takes a BifrostResponse and returns the streaming format expected by the specific integration. +type StreamResponseConverter func(*schemas.BifrostResponse) (interface{}, error) + +// ErrorConverter is a function that converts BifrostError to integration-specific format. +// It takes a BifrostError and returns the format expected by the specific integration. +type ErrorConverter func(*schemas.BifrostError) interface{} + +// StreamErrorConverter is a function that converts BifrostError to integration-specific streaming error format. +// It takes a BifrostError and returns the streaming error format expected by the specific integration. +type StreamErrorConverter func(*schemas.BifrostError) interface{} + +// RequestParser is a function that handles custom request body parsing. +// It replaces the default JSON parsing when configured (e.g., for multipart/form-data). +// The parser should populate the provided request object from the fasthttp context. +// If it returns an error, the request processing stops. +type RequestParser func(ctx *fasthttp.RequestCtx, req interface{}) error + +// PreRequestCallback is called after parsing the request but before processing through Bifrost. +// It can be used to modify the request object (e.g., extract model from URL parameters) +// or perform validation. If it returns an error, the request processing stops. +type PreRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}) error + +// PostRequestCallback is called after processing the request but before sending the response. +// It can be used to modify the response or perform additional logging/metrics. +// If it returns an error, an error response is sent instead of the success response. +type PostRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}, resp *schemas.BifrostResponse) error + +// StreamConfig defines streaming-specific configuration for an integration +// +// SSE FORMAT BEHAVIOR: +// +// The ResponseConverter and ErrorConverter functions in StreamConfig can return either: +// +// 1. OBJECTS (interface{} that's not a string): +// - Will be JSON marshaled and sent as standard SSE: data: {json}\n\n +// - Use this for most providers (OpenAI, Google, etc.) +// - Example: return map[string]interface{}{"delta": {"content": "hello"}} +// - Result: data: {"delta":{"content":"hello"}}\n\n +// +// 2. STRINGS: +// - Will be sent directly as-is without any modification +// - Use this for providers requiring custom SSE event types (Anthropic, etc.) +// - Example: return "event: content_block_delta\ndata: {\"type\":\"text\"}\n\n" +// - Result: event: content_block_delta +// data: {"type":"text"} +// +// Choose the appropriate return type based on your provider's SSE specification. +type StreamConfig struct { + ResponseConverter StreamResponseConverter // Function to convert BifrostResponse to streaming format + ErrorConverter StreamErrorConverter // Function to convert BifrostError to streaming error format +} + +// RouteConfig defines the configuration for a single route in an integration. +// It specifies the path, method, and handlers for request/response conversion. +type RouteConfig struct { + Path string // HTTP path pattern (e.g., "/openai/v1/chat/completions") + Method string // HTTP method (POST, GET, PUT, DELETE) + GetRequestTypeInstance func() interface{} // Factory function to create request instance (SHOULD NOT BE NIL) + RequestParser RequestParser // Optional: custom request parsing (e.g., multipart/form-data) + RequestConverter RequestConverter // Function to convert request to BifrostRequest (SHOULD NOT BE NIL) + ResponseConverter ResponseConverter // Function to convert BifrostResponse to integration format (SHOULD NOT BE NIL) + ErrorConverter ErrorConverter // Function to convert BifrostError to integration format (SHOULD NOT BE NIL) + StreamConfig *StreamConfig // Optional: Streaming configuration (if nil, streaming not supported) + PreCallback PreRequestCallback // Optional: called after parsing but before Bifrost processing + PostCallback PostRequestCallback // Optional: called after request processing +} + +// DefaultParameters defines the common parameters that most providers support +var DefaultParameters = map[string]bool{ + "max_tokens": true, + "temperature": true, + "top_p": true, + "stream": true, + "tools": true, + "tool_choice": true, +} + +// ProviderParameterSchema defines which parameters are valid for each provider +type ProviderParameterSchema struct { + ValidParams map[string]bool // Parameters that are supported by this provider +} + +// ParameterValidator validates and filters parameters for specific providers +type ParameterValidator struct { + schemas map[schemas.ModelProvider]ProviderParameterSchema +} + +// NewParameterValidator creates a new validator with provider schemas +func NewParameterValidator() *ParameterValidator { + return &ParameterValidator{ + schemas: buildProviderSchemas(), + } +} + +// ValidateAndFilterParams filters out invalid parameters for the target provider +func (v *ParameterValidator) ValidateAndFilterParams( + provider schemas.ModelProvider, + params *schemas.ModelParameters, +) *schemas.ModelParameters { + if params == nil { + return nil + } + + schema, exists := v.schemas[provider] + if !exists { + // Unknown provider, return all params (fallback behavior) + return params + } + + filteredParams := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + // Filter standard parameters + if params.MaxTokens != nil && schema.ValidParams["max_tokens"] { + filteredParams.MaxTokens = params.MaxTokens + } + + if params.Temperature != nil && schema.ValidParams["temperature"] { + filteredParams.Temperature = params.Temperature + } + + if params.TopP != nil && schema.ValidParams["top_p"] { + filteredParams.TopP = params.TopP + } + + if params.TopK != nil && schema.ValidParams["top_k"] { + filteredParams.TopK = params.TopK + } + + if params.PresencePenalty != nil && schema.ValidParams["presence_penalty"] { + filteredParams.PresencePenalty = params.PresencePenalty + } + + if params.FrequencyPenalty != nil && schema.ValidParams["frequency_penalty"] { + filteredParams.FrequencyPenalty = params.FrequencyPenalty + } + + if params.StopSequences != nil && schema.ValidParams["stop_sequences"] { + filteredParams.StopSequences = params.StopSequences + } + + if params.Tools != nil && schema.ValidParams["tools"] { + filteredParams.Tools = params.Tools + } + + if params.ToolChoice != nil && schema.ValidParams["tool_choice"] { + filteredParams.ToolChoice = params.ToolChoice + } + + if params.User != nil && schema.ValidParams["user"] { + filteredParams.User = params.User + } + + if params.EncodingFormat != nil && schema.ValidParams["encoding_format"] { + filteredParams.EncodingFormat = params.EncodingFormat + } + + if params.Dimensions != nil && schema.ValidParams["dimensions"] { + filteredParams.Dimensions = params.Dimensions + } + + // Parallel tool calls + if params.ParallelToolCalls != nil && schema.ValidParams["parallel_tool_calls"] { + filteredParams.ParallelToolCalls = params.ParallelToolCalls + } + + // Filter extra parameters + for key, value := range params.ExtraParams { + if schema.ValidParams[key] { + filteredParams.ExtraParams[key] = value + } + } + + // Check if all standard pointer fields are nil and ExtraParams is empty + if hasNoValidFields(filteredParams) && len(filteredParams.ExtraParams) == 0 { + return nil + } + + return filteredParams +} + +// hasNoValidFields checks if all standard pointer fields in ModelParameters are nil +func hasNoValidFields(params *schemas.ModelParameters) bool { + return params.ToolChoice == nil && + params.Tools == nil && + params.Temperature == nil && + params.TopP == nil && + params.TopK == nil && + params.MaxTokens == nil && + params.StopSequences == nil && + params.PresencePenalty == nil && + params.FrequencyPenalty == nil && + params.ParallelToolCalls == nil && + params.EncodingFormat == nil && + params.Dimensions == nil && + params.User == nil +} + +// buildProviderSchemas defines which parameters are valid for each provider +func buildProviderSchemas() map[schemas.ModelProvider]ProviderParameterSchema { + // Define parameter groups to avoid repetition + openAIParams := map[string]bool{ + "frequency_penalty": true, + "presence_penalty": true, + "n": true, + "stop": true, + "logprobs": true, + "top_logprobs": true, + "logit_bias": true, + "seed": true, + "user": true, + "response_format": true, + "parallel_tool_calls": true, + "max_completion_tokens": true, + "metadata": true, + "modalities": true, + "prediction": true, + "reasoning_effort": true, + "service_tier": true, + "store": true, + "speed": true, + "language": true, + "prompt": true, + "include": true, + "timestamp_granularities": true, + "encoding_format": true, + "dimensions": true, + "stream_options": true, + } + + anthropicParams := map[string]bool{ + "stop_sequences": true, + "system": true, + "metadata": true, + "mcp_servers": true, + "service_tier": true, + "thinking": true, + "top_k": true, + } + + cohereParams := map[string]bool{ + "frequency_penalty": true, + "presence_penalty": true, + "k": true, + "p": true, + "truncate": true, + "return_likelihoods": true, + "logit_bias": true, + "stop_sequences": true, + } + + mistralParams := map[string]bool{ + "frequency_penalty": true, + "presence_penalty": true, + "safe_mode": true, + "n": true, + "parallel_tool_calls": true, + "prediction": true, + "prompt_mode": true, + "random_seed": true, + "response_format": true, + "safe_prompt": true, + "top_k": true, + } + + groqParams := map[string]bool{ + "n": true, + "reasoning_effort": true, + "reasoning_format": true, + "service_tier": true, + "stop": true, + } + + ollamaParams := map[string]bool{ + "num_ctx": true, + "num_gpu": true, + "num_thread": true, + "repeat_penalty": true, + "repeat_last_n": true, + "seed": true, + "tfs_z": true, + "mirostat": true, + "mirostat_tau": true, + "mirostat_eta": true, + "format": true, + "keep_alive": true, + "low_vram": true, + "main_gpu": true, + "min_p": true, + "num_batch": true, + "num_keep": true, + "num_predict": true, + "numa": true, + "penalize_newline": true, + "raw": true, + "typical_p": true, + "use_mlock": true, + "use_mmap": true, + "vocab_only": true, + } + + // Vertex supports both OpenAI and Anthropic models, plus its own specific parameters + vertexParams := mergeWithDefaults(openAIParams) + // Add Anthropic-specific parameters for Claude models on Vertex + for k, v := range anthropicParams { + vertexParams[k] = v + } + // Add Vertex-specific parameters + vertexSpecificParams := map[string]bool{ + "task_type": true, // For embeddings + "title": true, // For embeddings + "autoTruncate": true, // For embeddings + "outputDimensionality": true, // For embeddings (maps to dimensions) + } + for k, v := range vertexSpecificParams { + vertexParams[k] = v + } + + // Bedrock supports both Anthropic and Mistral models, plus its own specific parameters + bedrockParams := mergeWithDefaults(anthropicParams) + // Add Mistral-specific parameters for Mistral models on Bedrock + for k, v := range mistralParams { + bedrockParams[k] = v + } + // Add Bedrock-specific parameters + bedrockSpecificParams := map[string]bool{ + "max_tokens_to_sample": true, // Anthropic models use this instead of max_tokens + "toolConfig": true, // Bedrock-specific tool configuration + "input_type": true, // For Cohere embeddings + } + for k, v := range bedrockSpecificParams { + bedrockParams[k] = v + } + + geminiParams := mergeWithDefaults(openAIParams) + geminiParams["top_k"] = true + geminiParams["stop_sequences"] = true + + openRouterSpecificParams := map[string]bool{ + "transforms": true, + "models": true, + "route": true, + "provider": true, + "prediction": true, // Reduce latency by providing the model with a predicted output + "top_a": true, // Range: [0, 1] + "min_p": true, // Range: [0, 1] + } + openRouterParams := mergeWithDefaults(openAIParams) + for k, v := range openRouterSpecificParams { + openRouterParams[k] = v + } + + return map[schemas.ModelProvider]ProviderParameterSchema{ + schemas.OpenAI: {ValidParams: mergeWithDefaults(openAIParams)}, + schemas.Azure: {ValidParams: mergeWithDefaults(openAIParams)}, + schemas.Anthropic: {ValidParams: mergeWithDefaults(anthropicParams)}, + schemas.Cohere: {ValidParams: mergeWithDefaults(cohereParams)}, + schemas.Mistral: {ValidParams: mergeWithDefaults(mistralParams)}, + schemas.Groq: {ValidParams: mergeWithDefaults(groqParams)}, + schemas.Bedrock: {ValidParams: bedrockParams}, + schemas.Vertex: {ValidParams: vertexParams}, + schemas.Ollama: {ValidParams: mergeWithDefaults(ollamaParams)}, + schemas.Cerebras: {ValidParams: mergeWithDefaults(openAIParams)}, + schemas.SGL: {ValidParams: mergeWithDefaults(openAIParams)}, + schemas.Parasail: {ValidParams: mergeWithDefaults(openAIParams)}, + schemas.Gemini: {ValidParams: geminiParams}, + schemas.OpenRouter: {ValidParams: openRouterParams}, + } +} + +// mergeWithDefaults merges provider-specific parameters with default parameters +func mergeWithDefaults(providerParams map[string]bool) map[string]bool { + result := make(map[string]bool, len(DefaultParameters)+len(providerParams)) + + // Copy default parameters + for k, v := range DefaultParameters { + result[k] = v + } + + // Add provider-specific parameters + for k, v := range providerParams { + result[k] = v + } + + return result +} + +// Global parameter validator instance +var globalParamValidator = NewParameterValidator() + +// SetGlobalParameterValidator sets the shared ParameterValidator instance. +// It’s primarily intended for test setup or one-time overrides. +// Note: calling this at runtime from multiple goroutines is not safe for concurrent use. +func SetGlobalParameterValidator(v *ParameterValidator) { + if v != nil { + globalParamValidator = v + } +} + +// ValidateAndFilterParamsForProvider is a convenience function that uses the global validator +// to filter parameters for a specific provider. This is the main function integrations should use. +func ValidateAndFilterParamsForProvider( + provider schemas.ModelProvider, + params *schemas.ModelParameters, +) *schemas.ModelParameters { + return globalParamValidator.ValidateAndFilterParams(provider, params) +} + +// GenericRouter provides a reusable router implementation for all integrations. +// It handles the common flow of: parse request β†’ convert to Bifrost β†’ execute β†’ convert response. +// Integration-specific logic is handled through the RouteConfig callbacks and converters. +type GenericRouter struct { + client *bifrost.Bifrost // Bifrost client for executing requests + handlerStore lib.HandlerStore // Config provider for the router + routes []RouteConfig // List of route configurations +} + +// NewGenericRouter creates a new generic router with the given bifrost client and route configurations. +// Each integration should create their own routes and pass them to this constructor. +func NewGenericRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, routes []RouteConfig) *GenericRouter { + return &GenericRouter{ + client: client, + handlerStore: handlerStore, + routes: routes, + } +} + +// RegisterRoutes registers all configured routes on the given fasthttp router. +// This method implements the ExtensionRouter interface. +func (g *GenericRouter) RegisterRoutes(r *router.Router) { + for _, route := range g.routes { + // Validate route configuration at startup to fail fast + if route.GetRequestTypeInstance == nil { + log.Println("[WARN] route configuration is invalid: GetRequestTypeInstance cannot be nil for route " + route.Path) + continue + } + if route.RequestConverter == nil { + log.Println("[WARN] route configuration is invalid: RequestConverter cannot be nil for route " + route.Path) + continue + } + if route.ResponseConverter == nil { + log.Println("[WARN] route configuration is invalid: ResponseConverter cannot be nil for route " + route.Path) + continue + } + if route.ErrorConverter == nil { + log.Println("[WARN] route configuration is invalid: ErrorConverter cannot be nil for route " + route.Path) + continue + } + + // Test that GetRequestTypeInstance returns a valid instance + if testInstance := route.GetRequestTypeInstance(); testInstance == nil { + log.Println("[WARN] route configuration is invalid: GetRequestTypeInstance returned nil for route " + route.Path) + continue + } + + handler := g.createHandler(route) + switch strings.ToUpper(route.Method) { + case fasthttp.MethodPost: + r.POST(route.Path, handler) + case fasthttp.MethodGet: + r.GET(route.Path, handler) + case fasthttp.MethodPut: + r.PUT(route.Path, handler) + case fasthttp.MethodDelete: + r.DELETE(route.Path, handler) + default: + r.POST(route.Path, handler) // Default to POST + } + } +} + +// createHandler creates a fasthttp handler for the given route configuration. +// The handler follows this flow: +// 1. Parse JSON request body into the configured request type (for methods that expect bodies) +// 2. Execute pre-callback (if configured) for request modification/validation +// 3. Convert request to BifrostRequest using the configured converter +// 4. Execute the request through Bifrost (streaming or non-streaming) +// 5. Execute post-callback (if configured) for response modification +// 6. Convert and send the response using the configured response converter +func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Parse request body into the integration-specific request type + // Note: config validation is performed at startup in RegisterRoutes + req := config.GetRequestTypeInstance() + + method := string(ctx.Method()) + + // Parse request body based on configuration + if method != fasthttp.MethodGet && method != fasthttp.MethodDelete { + if config.RequestParser != nil { + // Use custom parser (e.g., for multipart/form-data) + if err := config.RequestParser(ctx, req); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to parse request")) + return + } + } else { + // Use default JSON parsing + body := ctx.Request.Body() + if len(body) > 0 { + if err := json.Unmarshal(body, req); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "Invalid JSON")) + return + } + } + } + } + + // Execute pre-request callback if configured + // This is typically used for extracting data from URL parameters + // or performing request validation after parsing + if config.PreCallback != nil { + if err := config.PreCallback(ctx, req); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to execute pre-request callback: "+err.Error())) + return + } + } + + // Convert the integration-specific request to Bifrost format + bifrostReq, err := config.RequestConverter(req) + if err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to convert request to Bifrost format")) + return + } + if bifrostReq == nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Invalid request")) + return + } + if bifrostReq.Model == "" { + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Model parameter is required")) + return + } + + // Check if streaming is requested + isStreaming := false + if streamingReq, ok := req.(StreamingRequest); ok { + isStreaming = streamingReq.IsStreamingRequested() + } + + // Execute the request through Bifrost + bifrostCtx := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys()) + + if ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)) != nil { + key, ok := ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)).(schemas.Key) + if ok { + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + } + } + + if isStreaming { + g.handleStreamingRequest(ctx, config, req, bifrostReq, bifrostCtx) + } else { + g.handleNonStreamingRequest(ctx, config, req, bifrostReq, bifrostCtx) + } + } +} + +// handleNonStreamingRequest handles regular (non-streaming) requests +func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context) { + var result *schemas.BifrostResponse + var bifrostErr *schemas.BifrostError + + // Handle different request types + if bifrostReq.Input.TextCompletionInput != nil { + result, bifrostErr = g.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.ChatCompletionInput != nil { + result, bifrostErr = g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.EmbeddingInput != nil { + result, bifrostErr = g.client.EmbeddingRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.SpeechInput != nil { + result, bifrostErr = g.client.SpeechRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.TranscriptionInput != nil { + result, bifrostErr = g.client.TranscriptionRequest(*bifrostCtx, bifrostReq) + } + + // Handle errors + if bifrostErr != nil { + g.sendError(ctx, config.ErrorConverter, bifrostErr) + return + } + + // Execute post-request callback if configured + // This is typically used for response modification or additional processing + if config.PostCallback != nil { + if err := config.PostCallback(ctx, req, result); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to execute post-request callback")) + return + } + } + + if result == nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + return + } + + // Convert Bifrost response to integration-specific format and send + response, err := config.ResponseConverter(result) + if err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to encode response")) + return + } + + if result.Speech != nil { + responseBytes, ok := response.([]byte) + if ok { + ctx.Response.Header.Set("Content-Type", "audio/mpeg") + ctx.Response.Header.Set("Content-Disposition", "attachment; filename=speech.mp3") + ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(responseBytes))) + ctx.Response.SetBody(responseBytes) + return + } + } + + g.sendSuccess(ctx, config.ErrorConverter, response) +} + +// handleStreamingRequest handles streaming requests using Server-Sent Events (SSE) +func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context) { + // Set common SSE headers + ctx.SetContentType("text/event-stream") + ctx.Response.Header.Set("Cache-Control", "no-cache") + ctx.Response.Header.Set("Connection", "keep-alive") + ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") + + var stream chan *schemas.BifrostStream + var bifrostErr *schemas.BifrostError + + // Handle different request types + if bifrostReq.Input.ChatCompletionInput != nil { + stream, bifrostErr = g.client.ChatCompletionStreamRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.SpeechInput != nil { + stream, bifrostErr = g.client.SpeechStreamRequest(*bifrostCtx, bifrostReq) + } else if bifrostReq.Input.TranscriptionInput != nil { + stream, bifrostErr = g.client.TranscriptionStreamRequest(*bifrostCtx, bifrostReq) + } + + // Get the streaming channel from Bifrost + if bifrostErr != nil { + // Send error in SSE format + g.sendStreamError(ctx, config, bifrostErr) + return + } + + // Check if streaming is configured for this route + if config.StreamConfig == nil { + g.sendStreamError(ctx, config, newBifrostError(nil, "streaming is not supported for this integration")) + return + } + + // Handle streaming using the centralized approach + g.handleStreaming(ctx, config, stream) +} + +// handleStreaming processes a stream of BifrostResponse objects and sends them as Server-Sent Events (SSE). +// It handles both successful responses and errors in the streaming format. +// +// SSE FORMAT HANDLING: +// +// By default, all responses and errors are sent in the standard SSE format: +// +// data: {"response": "content"}\n\n +// +// However, some providers (like Anthropic) require custom SSE event formats with explicit event types: +// +// event: content_block_delta +// data: {"type": "content_block_delta", "delta": {...}} +// +// event: message_stop +// data: {"type": "message_stop"} +// +// STREAMCONFIG CONVERTER BEHAVIOR: +// +// The StreamConfig.ResponseConverter and StreamConfig.ErrorConverter functions can return: +// +// 1. OBJECTS (default behavior): +// - Return any Go struct/map/interface{} +// - Will be JSON marshaled and wrapped as: data: {json}\n\n +// - Example: return map[string]interface{}{"content": "hello"} +// - Result: data: {"content":"hello"}\n\n +// +// 2. STRINGS (custom SSE format): +// - Return a complete SSE string with custom event types and formatting +// - Will be sent directly without any wrapping or modification +// - Example: return "event: content_block_delta\ndata: {\"type\":\"text\"}\n\n" +// - Result: event: content_block_delta +// data: {"type":"text"} +// +// IMPLEMENTATION GUIDELINES: +// +// For standard providers (OpenAI, etc.): Return objects from converters +// For custom SSE providers (Anthropic, etc.): Return pre-formatted SSE strings +// +// When returning strings, ensure they: +// - Include proper event: lines (if needed) +// - Include data: lines with JSON content +// - End with \n\n for proper SSE formatting +// - Follow the provider's specific SSE event specification +func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, config RouteConfig, streamChan chan *schemas.BifrostStream) { + // Use streaming response writer + ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { + defer w.Flush() + + // Process streaming responses + for response := range streamChan { + if response == nil { + continue + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return + default: + } + + // Handle errors + if response.BifrostError != nil { + var errorResponse interface{} + var errorJSON []byte + var err error + + // Use stream error converter if available, otherwise fallback to regular error converter + if config.StreamConfig != nil && config.StreamConfig.ErrorConverter != nil { + errorResponse = config.StreamConfig.ErrorConverter(response.BifrostError) + } else if config.ErrorConverter != nil { + errorResponse = config.ErrorConverter(response.BifrostError) + } else { + // Default error response + errorResponse = map[string]interface{}{ + "error": map[string]interface{}{ + "type": "internal_error", + "message": "An error occurred while processing your request", + }, + } + } + + // Check if the error converter returned a raw SSE string or JSON object + if sseErrorString, ok := errorResponse.(string); ok { + // CUSTOM SSE FORMAT: The converter returned a complete SSE string + // This is used by providers like Anthropic that need custom event types + // Example: "event: error\ndata: {...}\n\n" + if _, err := fmt.Fprint(w, sseErrorString); err != nil { + return + } + } else { + // STANDARD SSE FORMAT: The converter returned an object + // This will be JSON marshaled and wrapped as "data: {json}\n\n" + // Used by most providers (OpenAI, Google, etc.) + errorJSON, err = json.Marshal(errorResponse) + if err != nil { + // Fallback to basic error if marshaling fails + basicError := map[string]interface{}{ + "error": map[string]interface{}{ + "type": "internal_error", + "message": "An error occurred while processing your request", + }, + } + if errorJSON, err = json.Marshal(basicError); err != nil { + return // Can't even send basic error + } + } + + // Send error as SSE data + if _, err := fmt.Fprintf(w, "data: %s\n\n", errorJSON); err != nil { + return + } + } + + // Flush and return on error + if err := w.Flush(); err != nil { + return + } + return // End stream on error + } + + // Handle successful responses + if response.BifrostResponse != nil { + // Convert response to integration-specific streaming format + var convertedResponse interface{} + var err error + + if config.StreamConfig.ResponseConverter != nil { + convertedResponse, err = config.StreamConfig.ResponseConverter(response.BifrostResponse) + } else { + // Fallback to regular response converter + convertedResponse, err = config.ResponseConverter(response.BifrostResponse) + } + + if err != nil { + // Log conversion error but continue processing + log.Printf("Failed to convert streaming response: %v", err) + continue + } + + // Check if the converter returned a raw SSE string or JSON object + if sseString, ok := convertedResponse.(string); ok { + // CUSTOM SSE FORMAT: The converter returned a complete SSE string + // This is used by providers like Anthropic that need custom event types + // Example: "event: content_block_delta\ndata: {...}\n\n" + if _, err := fmt.Fprint(w, sseString); err != nil { + return // Network error, stop streaming + } + } else { + // STANDARD SSE FORMAT: The converter returned an object + // This will be JSON marshaled and wrapped as "data: {json}\n\n" + // Used by most providers (OpenAI, Google, etc.) + responseJSON, err := json.Marshal(convertedResponse) + if err != nil { + // Log JSON marshaling error but continue processing + log.Printf("Failed to marshal streaming response: %v", err) + continue + } + + // Send as SSE data + if _, err := fmt.Fprintf(w, "data: %s\n\n", responseJSON); err != nil { + return // Network error, stop streaming + } + } + + // Flush immediately to send the chunk + if err := w.Flush(); err != nil { + return // Network error, stop streaming + } + } + } + }) +} + +// sendStreamError sends an error in streaming format using the stream error converter if available +func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, config RouteConfig, bifrostErr *schemas.BifrostError) { + var errorResponse interface{} + + // Use stream error converter if available, otherwise fallback to regular error converter + if config.StreamConfig != nil && config.StreamConfig.ErrorConverter != nil { + errorResponse = config.StreamConfig.ErrorConverter(bifrostErr) + } else { + errorResponse = config.ErrorConverter(bifrostErr) + } + + errorJSON, err := json.Marshal(map[string]interface{}{ + "error": errorResponse, + }) + if err != nil { + log.Printf("Failed to marshal error for SSE: %v", err) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + return + } + + if _, err := fmt.Fprintf(ctx, "data: %s\n\n", errorJSON); err != nil { + log.Printf("Failed to write SSE error: %v", err) + } +} + +// sendError sends an error response with the appropriate status code and JSON body. +// It handles different error types (string, error interface, or arbitrary objects). +func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) { + if bifrostErr.StatusCode != nil { + ctx.SetStatusCode(*bifrostErr.StatusCode) + } else { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + } + ctx.SetContentType("application/json") + + errorBody, err := json.Marshal(errorConverter(bifrostErr)) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err)) + return + } + + ctx.SetBody(errorBody) +} + +// sendSuccess sends a successful response with HTTP 200 status and JSON body. +func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, response interface{}) { + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + + responseBody, err := json.Marshal(response) + if err != nil { + g.sendError(ctx, errorConverter, newBifrostError(err, "failed to encode response")) + return + } + + ctx.SetBody(responseBody) +} + +// ValidProviders is a pre-computed map for efficient O(1) provider validation. +var ValidProviders = map[schemas.ModelProvider]bool{ + schemas.OpenAI: true, + schemas.Azure: true, + schemas.Anthropic: true, + schemas.Bedrock: true, + schemas.Cohere: true, + schemas.Vertex: true, + schemas.Mistral: true, + schemas.Ollama: true, + schemas.Groq: true, + schemas.SGL: true, + schemas.Parasail: true, + schemas.Cerebras: true, + schemas.Gemini: true, + schemas.OpenRouter: true, +} + +// ParseModelString extracts provider and model from a model string. +// For model strings like "anthropic/claude", it returns ("anthropic", "claude"). +// For model strings like "claude", it returns ("", "claude"). +func ParseModelString(model string, defaultProvider schemas.ModelProvider, checkProviderFromModel bool) (schemas.ModelProvider, string) { + // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") + if strings.Contains(model, "/") { + parts := strings.SplitN(model, "/", 2) + if len(parts) == 2 { + extractedProvider := parts[0] + extractedModel := parts[1] + + return schemas.ModelProvider(extractedProvider), extractedModel + } + } + + //TODO add model wise check for provider + + // No provider prefix found, return empty provider and the original model + return defaultProvider, model +} + +// GetProviderFromModel determines the appropriate provider based on model name patterns +// This function uses comprehensive pattern matching to identify the correct provider +// for various model naming conventions used across different AI providers. +func GetProviderFromModel(model string) schemas.ModelProvider { + // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") + if strings.Contains(model, "/") { + parts := strings.SplitN(model, "/", 2) + if len(parts) > 1 { + extractedProvider := parts[0] + + if ValidProviders[schemas.ModelProvider(extractedProvider)] { + return schemas.ModelProvider(extractedProvider) + } + } + } + + // Normalize model name for case-insensitive matching + modelLower := strings.ToLower(strings.TrimSpace(model)) + + // Azure OpenAI Models - check first to prevent false positives from OpenAI "gpt" patterns + if isAzureModel(modelLower) { + return schemas.Azure + } + + // OpenAI Models - comprehensive pattern matching + if isOpenAIModel(modelLower) { + return schemas.OpenAI + } + + // Anthropic Models - Claude family + if isAnthropicModel(modelLower) { + return schemas.Anthropic + } + + // Google Vertex AI Models - Gemini and Palm family + if isVertexModel(modelLower) { + return schemas.Vertex + } + + // AWS Bedrock Models - various model providers through Bedrock + if isBedrockModel(modelLower) { + return schemas.Bedrock + } + + // Cohere Models - Command and Embed family + if isCohereModel(modelLower) { + return schemas.Cohere + } + + // Google GenAI Models - Gemini and Palm family + if isGeminiModel(modelLower) { + return schemas.Gemini + } + + // Default to OpenAI for unknown models (most LiteLLM compatible) + return schemas.OpenAI +} + +// isOpenAIModel checks for OpenAI model patterns +func isOpenAIModel(model string) bool { + // Exclude Azure models to prevent overlap + if strings.Contains(model, "azure/") { + return false + } + + openaiPatterns := []string{ + "gpt", "davinci", "curie", "babbage", "ada", "o1", "o3", "o4", + "text-embedding", "dall-e", "whisper", "tts", "chatgpt", + } + + return matchesAnyPattern(model, openaiPatterns) +} + +// isAzureModel checks for Azure OpenAI specific patterns +func isAzureModel(model string) bool { + azurePatterns := []string{ + "azure", "model-router", "computer-use-preview", + } + + return matchesAnyPattern(model, azurePatterns) +} + +// isAnthropicModel checks for Anthropic Claude model patterns +func isAnthropicModel(model string) bool { + anthropicPatterns := []string{ + "claude", "anthropic/", + } + + return matchesAnyPattern(model, anthropicPatterns) +} + +var geminiRegexp = regexp.MustCompile(`\b(gemini|gemini-embedding|palm|bison|gecko)\b`) + +// isGeminiModel checks for Google Gemini model patterns using strict regex matching +func isGeminiModel(model string) bool { + return geminiRegexp.MatchString(model) +} + +// isVertexModel checks for Google Vertex AI model patterns +func isVertexModel(model string) bool { + vertexPatterns := []string{ + "gemini", "palm", "bison", "gecko", "vertex/", "google/", + } + + return matchesAnyPattern(model, vertexPatterns) +} + +// isBedrockModel checks for AWS Bedrock model patterns +func isBedrockModel(model string) bool { + bedrockPatterns := []string{ + "bedrock", "bedrock.amazonaws.com/", "bedrock/", + "amazon.titan", "amazon.nova", "aws/amazon.", + "ai21.jamba", "ai21.j2", "aws/ai21.", + "meta.llama", "aws/meta.", + "stability.stable-diffusion", "stability.sd3", "aws/stability.", + "anthropic.claude", "aws/anthropic.", + "cohere.command", "cohere.embed", "aws/cohere.", + "mistral.mistral", "mistral.mixtral", "aws/mistral.", + "titan-text", "titan-embed", "nova-micro", "nova-lite", "nova-pro", + "jamba-instruct", "j2-ultra", "j2-mid", + "llama-2", "llama-3", "llama-3.1", "llama-3.2", + "stable-diffusion-xl", "sd3-large", + } + + return matchesAnyPattern(model, bedrockPatterns) +} + +// isCohereModel checks for Cohere model patterns +func isCohereModel(model string) bool { + coherePatterns := []string{ + "command-", "embed-", "cohere", + } + + return matchesAnyPattern(model, coherePatterns) +} + +// matchesAnyPattern checks if the model matches any of the given patterns +func matchesAnyPattern(model string, patterns []string) bool { + for _, pattern := range patterns { + if strings.Contains(model, pattern) { + return true + } + } + return false +} + +// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. +// This helper function reduces code duplication when handling non-Bifrost errors. +func newBifrostError(err error, message string) *schemas.BifrostError { + if err == nil { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + }, + } + } + + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + Error: err, + }, + } +} + +// MapFinishReasonToProvider maps OpenAI-compatible finish reasons to provider-specific format +func MapFinishReasonToProvider(finishReason string, targetProvider schemas.ModelProvider) string { + switch targetProvider { + case schemas.Anthropic: + return mapFinishReasonToAnthropic(finishReason) + default: + // For OpenAI, Azure, and other providers, pass through as-is + return finishReason + } +} + +// mapFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic format +func mapFinishReasonToAnthropic(finishReason string) string { + switch finishReason { + case "stop": + return "end_turn" + case "length": + return "max_tokens" + case "tool_calls": + return "tool_use" + default: + // Pass through other reasons like "pause_turn", "refusal", "stop_sequence", etc. + return finishReason + } +} diff --git a/transports/bifrost-http/lib/account.go b/transports/bifrost-http/lib/account.go new file mode 100644 index 000000000..59c6f579d --- /dev/null +++ b/transports/bifrost-http/lib/account.go @@ -0,0 +1,115 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +package lib + +import ( + "context" + "fmt" + + "github.com/maximhq/bifrost/core/schemas" +) + +// BaseAccount implements the Account interface for Bifrost. +// It manages provider configurations using a in-memory store for persistent storage. +// All data processing (environment variables, key configs) is done upfront in the store. +type BaseAccount struct { + store *Config // store for in-memory configuration +} + +// NewBaseAccount creates a new BaseAccount with the given store +func NewBaseAccount(store *Config) *BaseAccount { + return &BaseAccount{ + store: store, + } +} + +// GetConfiguredProviders returns a list of all configured providers. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + if baseAccount.store == nil { + return nil, fmt.Errorf("store not initialized") + } + + return baseAccount.store.GetAllProviders() +} + +// GetKeysForProvider returns the API keys configured for a specific provider. +// Keys are already processed (environment variables resolved) by the store. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + if baseAccount.store == nil { + return nil, fmt.Errorf("store not initialized") + } + + config, err := baseAccount.store.GetProviderConfigRaw(providerKey) + if err != nil { + return nil, err + } + + keys := config.Keys + + if baseAccount.store.ClientConfig.EnableGovernance { + if v := (*ctx).Value(schemas.BifrostContextKey("bf-governance-include-only-keys")); v != nil { + if includeOnlyKeys, ok := v.([]string); ok { + if len(includeOnlyKeys) == 0 { + // header present but empty means "no keys allowed" + keys = nil + } else { + set := make(map[string]struct{}, len(includeOnlyKeys)) + for _, id := range includeOnlyKeys { + set[id] = struct{}{} + } + filtered := make([]schemas.Key, 0, len(keys)) + for _, key := range keys { + if _, ok := set[key.ID]; ok { + filtered = append(filtered, key) + } + } + keys = filtered + } + } + } + } + + return keys, nil +} + +// GetConfigForProvider returns the complete configuration for a specific provider. +// Configuration is already fully processed (environment variables, key configs) by the store. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + if baseAccount.store == nil { + return nil, fmt.Errorf("store not initialized") + } + + config, err := baseAccount.store.GetProviderConfigRaw(providerKey) + if err != nil { + return nil, err + } + + providerConfig := &schemas.ProviderConfig{} + + if config.ProxyConfig != nil { + providerConfig.ProxyConfig = config.ProxyConfig + } + + if config.NetworkConfig != nil { + providerConfig.NetworkConfig = *config.NetworkConfig + } else { + providerConfig.NetworkConfig = schemas.DefaultNetworkConfig + } + + if config.ConcurrencyAndBufferSize != nil { + providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize + } else { + providerConfig.ConcurrencyAndBufferSize = schemas.DefaultConcurrencyAndBufferSize + } + + providerConfig.SendBackRawResponse = config.SendBackRawResponse + + if config.CustomProviderConfig != nil { + providerConfig.CustomProviderConfig = config.CustomProviderConfig + } + + return providerConfig, nil +} diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go new file mode 100644 index 000000000..932b46ab1 --- /dev/null +++ b/transports/bifrost-http/lib/config.go @@ -0,0 +1,2190 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +package lib + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/vectorstore" + "github.com/maximhq/bifrost/plugins/semanticcache" + "gorm.io/gorm" +) + +// HandlerStore provides access to runtime configuration values for handlers. +// This interface allows handlers to access only the configuration they need +// without depending on the entire ConfigStore, improving testability and decoupling. +type HandlerStore interface { + // ShouldAllowDirectKeys returns whether direct API keys in headers are allowed + ShouldAllowDirectKeys() bool +} + +// ConfigData represents the configuration data for the Bifrost HTTP transport. +// It contains the client configuration, provider configurations, MCP configuration, +// vector store configuration, config store configuration, and logs store configuration. +type ConfigData struct { + Client *configstore.ClientConfig `json:"client"` + Providers map[string]configstore.ProviderConfig `json:"providers"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` + Governance *configstore.GovernanceConfig `json:"governance,omitempty"` + VectorStoreConfig *vectorstore.Config `json:"vector_store,omitempty"` + ConfigStoreConfig *configstore.Config `json:"config_store,omitempty"` + LogsStoreConfig *logstore.Config `json:"logs_store,omitempty"` + Plugins []*schemas.PluginConfig `json:"plugins,omitempty"` +} + +// UnmarshalJSON unmarshals the ConfigData from JSON using internal unmarshallers +// for VectorStoreConfig, ConfigStoreConfig, and LogsStoreConfig to ensure proper +// type safety and configuration parsing. +func (cd *ConfigData) UnmarshalJSON(data []byte) error { + // First, unmarshal into a temporary struct to get all fields except the complex configs + type TempConfigData struct { + Client *configstore.ClientConfig `json:"client"` + Providers map[string]configstore.ProviderConfig `json:"providers"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` + Governance *configstore.GovernanceConfig `json:"governance,omitempty"` + VectorStoreConfig json.RawMessage `json:"vector_store,omitempty"` + ConfigStoreConfig json.RawMessage `json:"config_store,omitempty"` + LogsStoreConfig json.RawMessage `json:"logs_store,omitempty"` + Plugins []*schemas.PluginConfig `json:"plugins,omitempty"` + } + + var temp TempConfigData + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal config data: %w", err) + } + + // Set simple fields + cd.Client = temp.Client + cd.Providers = temp.Providers + cd.MCP = temp.MCP + cd.Governance = temp.Governance + cd.Plugins = temp.Plugins + + // Parse VectorStoreConfig using its internal unmarshaler + if len(temp.VectorStoreConfig) > 0 { + var vectorStoreConfig vectorstore.Config + if err := json.Unmarshal(temp.VectorStoreConfig, &vectorStoreConfig); err != nil { + return fmt.Errorf("failed to unmarshal vector store config: %w", err) + } + cd.VectorStoreConfig = &vectorStoreConfig + } + + // Parse ConfigStoreConfig using its internal unmarshaler + if len(temp.ConfigStoreConfig) > 0 { + var configStoreConfig configstore.Config + if err := json.Unmarshal(temp.ConfigStoreConfig, &configStoreConfig); err != nil { + return fmt.Errorf("failed to unmarshal config store config: %w", err) + } + cd.ConfigStoreConfig = &configStoreConfig + } + + // Parse LogsStoreConfig using its internal unmarshaler + if len(temp.LogsStoreConfig) > 0 { + var logsStoreConfig logstore.Config + if err := json.Unmarshal(temp.LogsStoreConfig, &logsStoreConfig); err != nil { + return fmt.Errorf("failed to unmarshal logs store config: %w", err) + } + cd.LogsStoreConfig = &logsStoreConfig + } + return nil +} + +// Config represents a high-performance in-memory configuration store for Bifrost. +// It provides thread-safe access to provider configurations with database persistence. +// +// Features: +// - Pure in-memory storage for ultra-fast access +// - Environment variable processing for API keys and key-level configurations +// - Thread-safe operations with read-write mutexes +// - Real-time configuration updates via HTTP API +// - Automatic database persistence for all changes +// - Support for provider-specific key configurations (Azure, Vertex, Bedrock) +type Config struct { + mu sync.RWMutex + muMCP sync.RWMutex + client *bifrost.Bifrost + + configPath string + + // Stores + ConfigStore configstore.ConfigStore + VectorStore vectorstore.VectorStore + LogsStore logstore.LogStore + + // In-memory storage + ClientConfig configstore.ClientConfig + Providers map[schemas.ModelProvider]configstore.ProviderConfig + MCPConfig *schemas.MCPConfig + GovernanceConfig *configstore.GovernanceConfig + + // Track which keys come from environment variables + EnvKeys map[string][]configstore.EnvKeyInfo + + // Plugin configs + Plugins []*schemas.PluginConfig +} + +var DefaultClientConfig = configstore.ClientConfig{ + DropExcessRequests: false, + PrometheusLabels: []string{}, + InitialPoolSize: schemas.DefaultInitialPoolSize, + EnableLogging: true, + EnableGovernance: true, + EnforceGovernanceHeader: false, + AllowDirectKeys: false, + AllowedOrigins: []string{}, + MaxRequestBodySizeMB: 100, +} + +// LoadConfig loads initial configuration from a JSON config file into memory +// with full preprocessing including environment variable resolution and key config parsing. +// All processing is done upfront to ensure zero latency when retrieving data. +// +// If the config file doesn't exist, the system starts with default configuration +// and users can add providers dynamically via the HTTP API. +// +// This method handles: +// - JSON config file parsing +// - Environment variable substitution for API keys (env.VARIABLE_NAME) +// - Key-level config processing for Azure, Vertex, and Bedrock (Endpoint, APIVersion, ProjectID, Region, AuthCredentials) +// - Case conversion for provider names (e.g., "OpenAI" -> "openai") +// - In-memory storage for ultra-fast access during request processing +// - Graceful handling of missing config files +func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { + // Initialize separate database connections for optimal performance at scale + configFilePath := filepath.Join(configDirPath, "config.json") + configDBPath := filepath.Join(configDirPath, "config.db") + logsDBPath := filepath.Join(configDirPath, "logs.db") + + config := &Config{ + configPath: configFilePath, + EnvKeys: make(map[string][]configstore.EnvKeyInfo), + Providers: make(map[schemas.ModelProvider]configstore.ProviderConfig), + } + + absConfigFilePath, err := filepath.Abs(configFilePath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for config file: %w", err) + } + + // Check if config file exists + data, err := os.ReadFile(configFilePath) + if err != nil { + // If config file doesn't exist, we will directly use the config store (create one if it doesn't exist) + if os.IsNotExist(err) { + logger.Info("config file not found at path: %s, initializing with default values", absConfigFilePath) + // Initializing with default values + config.ConfigStore, err = configstore.NewConfigStore(&configstore.Config{ + Enabled: true, + Type: configstore.ConfigStoreTypeSQLite, + Config: &configstore.SQLiteConfig{ + Path: configDBPath, + }, + }, logger) + if err != nil { + return nil, fmt.Errorf("failed to initialize config store: %w", err) + } + // Checking if client config already exist + clientConfig, err := config.ConfigStore.GetClientConfig() + if err != nil { + return nil, fmt.Errorf("failed to get client config: %w", err) + } + if clientConfig == nil { + clientConfig = &DefaultClientConfig + } else { + // For backward compatibility, we need to handle cases where config is already present but max request body size is not set + if clientConfig.MaxRequestBodySizeMB == 0 { + clientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + } + err = config.ConfigStore.UpdateClientConfig(clientConfig) + if err != nil { + return nil, fmt.Errorf("failed to update client config: %w", err) + } + config.ClientConfig = *clientConfig + // Checking if log store config already exist + logStoreConfig, err := config.ConfigStore.GetLogsStoreConfig() + if err != nil { + return nil, fmt.Errorf("failed to get logs store config: %w", err) + } + logger.Debug("log store config from DB: %v", logStoreConfig) + if logStoreConfig == nil { + logStoreConfig = &logstore.Config{ + Enabled: true, + Type: logstore.LogStoreTypeSQLite, + Config: &logstore.SQLiteConfig{ + Path: logsDBPath, + }, + } + } + logger.Info("config store initialized; initializing logs store.") + config.LogsStore, err = logstore.NewLogStore(logStoreConfig, logger) + if err != nil { + return nil, fmt.Errorf("failed to initialize logs store: %v", err) + } + err = config.ConfigStore.UpdateLogsStoreConfig(logStoreConfig) + if err != nil { + return nil, fmt.Errorf("failed to update logs store config: %w", err) + } + // No providers in database, auto-detect from environment + providers, err := config.ConfigStore.GetProvidersConfig() + if err != nil { + return nil, fmt.Errorf("failed to get providers config: %w", err) + } + if providers == nil { + config.autoDetectProviders() + providers = config.Providers + // Store providers config in database + err = config.ConfigStore.UpdateProvidersConfig(providers) + if err != nil { + return nil, fmt.Errorf("failed to update providers config: %w", err) + } + } else { + processedProviders := make(map[schemas.ModelProvider]configstore.ProviderConfig) + for providerKey, dbProvider := range providers { + provider := schemas.ModelProvider(providerKey) + // Convert database keys to schemas.Key + keys := make([]schemas.Key, len(dbProvider.Keys)) + for i, dbKey := range dbProvider.Keys { + keys[i] = schemas.Key{ + ID: dbKey.ID, // Key ID is passed in dbKey, not ID + Value: dbKey.Value, + Models: dbKey.Models, + Weight: dbKey.Weight, + AzureKeyConfig: dbKey.AzureKeyConfig, + VertexKeyConfig: dbKey.VertexKeyConfig, + BedrockKeyConfig: dbKey.BedrockKeyConfig, + } + + } + providerConfig := configstore.ProviderConfig{ + Keys: keys, + NetworkConfig: dbProvider.NetworkConfig, + ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize, + ProxyConfig: dbProvider.ProxyConfig, + SendBackRawResponse: dbProvider.SendBackRawResponse, + CustomProviderConfig: dbProvider.CustomProviderConfig, + } + if err := ValidateCustomProvider(providerConfig, provider); err != nil { + logger.Warn("invalid custom provider config for %s: %v", provider, err) + continue + } + processedProviders[provider] = providerConfig + } + config.Providers = processedProviders + } + // Checking if MCP config already exists + mcpConfig, err := config.ConfigStore.GetMCPConfig() + if err != nil { + return nil, fmt.Errorf("failed to get MCP config: %w", err) + } + if mcpConfig == nil { + if err := config.processMCPEnvVars(); err != nil { + logger.Warn("failed to process MCP env vars: %v", err) + } + if err := config.ConfigStore.UpdateMCPConfig(config.MCPConfig, config.EnvKeys); err != nil { + return nil, fmt.Errorf("failed to update MCP config: %w", err) + } + // Refresh from store to ensure parity with persisted state + if mcpConfig, err = config.ConfigStore.GetMCPConfig(); err != nil { + return nil, fmt.Errorf("failed to get MCP config after update: %w", err) + } + config.MCPConfig = mcpConfig + } else { + // Use the saved config from the store + config.MCPConfig = mcpConfig + } + // Checking if plugins already exist + plugins, err := config.ConfigStore.GetPlugins() + if err != nil { + return nil, fmt.Errorf("failed to get plugins: %w", err) + } + if plugins == nil { + config.Plugins = []*schemas.PluginConfig{} + } else { + config.Plugins = make([]*schemas.PluginConfig, len(plugins)) + for i, plugin := range plugins { + pluginConfig := &schemas.PluginConfig{ + Name: plugin.Name, + Enabled: plugin.Enabled, + Config: plugin.Config, + } + if plugin.Name == semanticcache.PluginName { + if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to add provider keys to semantic cache config: %v", err) + } + } + config.Plugins[i] = pluginConfig + } + } + // Load environment variable tracking + var dbEnvKeys map[string][]configstore.EnvKeyInfo + if dbEnvKeys, err = config.ConfigStore.GetEnvKeys(); err != nil { + return nil, err + } + config.EnvKeys = make(map[string][]configstore.EnvKeyInfo) + for envVar, dbEnvKey := range dbEnvKeys { + for _, dbEnvKey := range dbEnvKey { + config.EnvKeys[envVar] = append(config.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: dbEnvKey.EnvVar, + Provider: dbEnvKey.Provider, + KeyType: dbEnvKey.KeyType, + ConfigPath: dbEnvKey.ConfigPath, + KeyID: dbEnvKey.KeyID, + }) + } + } + err = config.ConfigStore.UpdateEnvKeys(config.EnvKeys) + if err != nil { + return nil, fmt.Errorf("failed to update env keys: %w", err) + } + return config, nil + } + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + // If config file exists, we will use it to only bootstrap config tables. + + logger.Info("loading configuration from: %s", absConfigFilePath) + + var configData ConfigData + if err := json.Unmarshal(data, &configData); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Initializing config store + if configData.ConfigStoreConfig != nil && configData.ConfigStoreConfig.Enabled { + config.ConfigStore, err = configstore.NewConfigStore(configData.ConfigStoreConfig, logger) + if err != nil { + return nil, err + } + logger.Info("config store initialized") + } + + // Initializing log store + if configData.LogsStoreConfig != nil && configData.LogsStoreConfig.Enabled { + config.LogsStore, err = logstore.NewLogStore(configData.LogsStoreConfig, logger) + if err != nil { + return nil, err + } + logger.Info("logs store initialized") + } + + // Initializing vector store + if configData.VectorStoreConfig != nil && configData.VectorStoreConfig.Enabled { + logger.Info("connecting to vectorstore") + // Checking type of the store + config.VectorStore, err = vectorstore.NewVectorStore(ctx, configData.VectorStoreConfig, logger) + if err != nil { + logger.Fatal("failed to connect to vector store: %v", err) + } + if config.ConfigStore != nil { + err = config.ConfigStore.UpdateVectorStoreConfig(configData.VectorStoreConfig) + if err != nil { + logger.Warn("failed to update vector store config: %v", err) + } + } + } + + // From now on, config store gets the priority if enabled and we find data + // if we don't find any data in the store, then we resort to config file + + //NOTE: We follow a standard practice here to first look in store -> not present then use config file -> if present in config file then update store. + + // 1. Check for Client Config + + var clientConfig *configstore.ClientConfig + if config.ConfigStore != nil { + clientConfig, err = config.ConfigStore.GetClientConfig() + if err != nil { + logger.Warn("failed to get client config from store: %v", err) + } + } + + if clientConfig != nil { + config.ClientConfig = *clientConfig + + // For backward compatibility, we need to handle cases where config is already present but max request body size is not set + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + } else { + logger.Debug("client config not found in store, using config file") + // Process core configuration if present, otherwise use defaults + if configData.Client != nil { + config.ClientConfig = *configData.Client + + // For backward compatibility, we need to handle cases where config is already present but max request body size is not set + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + } else { + config.ClientConfig = DefaultClientConfig + } + + if config.ConfigStore != nil { + logger.Debug("updating client config in store") + err = config.ConfigStore.UpdateClientConfig(&config.ClientConfig) + if err != nil { + logger.Warn("failed to update client config: %v", err) + } + } + } + + // 2. Check for Providers + + var processedProviders map[schemas.ModelProvider]configstore.ProviderConfig + if config.ConfigStore != nil { + logger.Debug("getting providers config from store") + processedProviders, err = config.ConfigStore.GetProvidersConfig() + if err != nil { + logger.Warn("failed to get providers config from store: %v", err) + } + } + + if processedProviders != nil { + config.Providers = processedProviders + } else { + // If we don't have any data in the store, we will process the data from the config file + logger.Debug("no providers config found in store, processing from config file") + processedProviders = make(map[schemas.ModelProvider]configstore.ProviderConfig) + // Process provider configurations + if configData.Providers != nil { + // Process each provider configuration + for providerName, cfg := range configData.Providers { + newEnvKeys := make(map[string]struct{}) + provider := schemas.ModelProvider(strings.ToLower(providerName)) + + // Process environment variables in keys (including key-level configs) + for i, key := range cfg.Keys { + if key.ID == "" { + cfg.Keys[i].ID = uuid.NewString() + } + + // Process API key value + processedValue, envVar, err := config.processEnvValue(key.Value) + if err != nil { + config.cleanupEnvKeys(provider, "", newEnvKeys) + if strings.Contains(err.Error(), "not found") { + logger.Info("%s: %v", provider, err) + } else { + logger.Warn("failed to process env vars in keys for %s: %v", provider, err) + } + continue + } + cfg.Keys[i].Value = processedValue + + // Track environment key if it came from env + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + config.EnvKeys[envVar] = append(config.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "api_key", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID), + KeyID: key.ID, + }) + } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := config.processAzureKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + config.cleanupEnvKeys(provider, "", newEnvKeys) + logger.Warn("failed to process Azure key config env vars for %s: %v", provider, err) + continue + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := config.processVertexKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + config.cleanupEnvKeys(provider, "", newEnvKeys) + logger.Warn("failed to process Vertex key config env vars for %s: %v", provider, err) + continue + } + } + + // Process Bedrock key config if present + if key.BedrockKeyConfig != nil { + if err := config.processBedrockKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + config.cleanupEnvKeys(provider, "", newEnvKeys) + logger.Warn("failed to process Bedrock key config env vars for %s: %v", provider, err) + continue + } + } + } + processedProviders[provider] = cfg + } + // Store processed configurations in memory + config.Providers = processedProviders + } else { + config.autoDetectProviders() + } + if config.ConfigStore != nil { + logger.Debug("updating providers config in store") + err = config.ConfigStore.UpdateProvidersConfig(processedProviders) + if err != nil { + logger.Warn("failed to update providers config: %v", err) + } + if err := config.ConfigStore.UpdateEnvKeys(config.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + } + + // 3. Check for MCP Config + + var mcpConfig *schemas.MCPConfig + if config.ConfigStore != nil { + logger.Debug("getting MCP config from store") + mcpConfig, err = config.ConfigStore.GetMCPConfig() + if err != nil { + logger.Warn("failed to get MCP config from store: %v", err) + } + } + + if mcpConfig != nil { + config.MCPConfig = mcpConfig + } else if configData.MCP != nil { + // If MCP config is not present in the store, we will use the config file + logger.Debug("no MCP config found in store, processing from config file") + config.MCPConfig = configData.MCP + if err := config.processMCPEnvVars(); err != nil { + logger.Warn("failed to process MCP env vars: %v", err) + } + if config.ConfigStore != nil { + logger.Debug("updating MCP config in store") + err = config.ConfigStore.UpdateMCPConfig(config.MCPConfig, config.EnvKeys) + if err != nil { + logger.Warn("failed to update MCP config: %v", err) + } + } + } + + // 4. Check for Governance Config + + var governanceConfig *configstore.GovernanceConfig + if config.ConfigStore != nil { + logger.Debug("getting governance config from store") + governanceConfig, err = config.ConfigStore.GetGovernanceConfig() + if err != nil { + logger.Warn("failed to get governance config from store: %v", err) + } + } + + if governanceConfig != nil { + config.GovernanceConfig = governanceConfig + } else if configData.Governance != nil { + logger.Debug("no governance config found in store, processing from config file") + config.GovernanceConfig = configData.Governance + + if config.ConfigStore != nil { + logger.Debug("updating governance config in store") + if err := config.ConfigStore.ExecuteTransaction(func(tx *gorm.DB) error { + // Create budgets + for _, budget := range config.GovernanceConfig.Budgets { + if err := config.ConfigStore.CreateBudget(&budget, tx); err != nil { + return fmt.Errorf("failed to create budget %s: %w", budget.ID, err) + } + } + + // Create rate limits + for _, rateLimit := range config.GovernanceConfig.RateLimits { + if err := config.ConfigStore.CreateRateLimit(&rateLimit, tx); err != nil { + return fmt.Errorf("failed to create rate limit %s: %w", rateLimit.ID, err) + } + } + + // Create customers + for _, customer := range config.GovernanceConfig.Customers { + if err := config.ConfigStore.CreateCustomer(&customer, tx); err != nil { + return fmt.Errorf("failed to create customer %s: %w", customer.ID, err) + } + } + + // Create teams + for _, team := range config.GovernanceConfig.Teams { + if err := config.ConfigStore.CreateTeam(&team, tx); err != nil { + return fmt.Errorf("failed to create team %s: %w", team.ID, err) + } + } + + // Create virtual keys + for _, virtualKey := range config.GovernanceConfig.VirtualKeys { + // Look up existing provider keys by key_id and populate the Keys field + var existingKeys []configstore.TableKey + for _, keyRef := range virtualKey.Keys { + if keyRef.KeyID != "" { + var existingKey configstore.TableKey + if err := tx.Where("key_id = ?", keyRef.KeyID).First(&existingKey).Error; err != nil { + if err == gorm.ErrRecordNotFound { + logger.Warn("referenced key %s not found for virtual key %s", keyRef.KeyID, virtualKey.ID) + continue + } + return fmt.Errorf("failed to lookup key %s for virtual key %s: %w", keyRef.KeyID, virtualKey.ID, err) + } + existingKeys = append(existingKeys, existingKey) + } + } + virtualKey.Keys = existingKeys + + if err := config.ConfigStore.CreateVirtualKey(&virtualKey, tx); err != nil { + return fmt.Errorf("failed to create virtual key %s: %w", virtualKey.ID, err) + } + } + + return nil + }); err != nil { + logger.Warn("failed to update governance config: %v", err) + } + } + } + + // 5. Check for Plugins + + if config.ConfigStore != nil { + logger.Debug("getting plugins from store") + plugins, err := config.ConfigStore.GetPlugins() + if err != nil { + logger.Warn("failed to get plugins from store: %v", err) + } + if plugins != nil { + config.Plugins = make([]*schemas.PluginConfig, len(plugins)) + for i, plugin := range plugins { + pluginConfig := &schemas.PluginConfig{ + Name: plugin.Name, + Enabled: plugin.Enabled, + Config: plugin.Config, + } + if plugin.Name == semanticcache.PluginName { + if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to add provider keys to semantic cache config: %v", err) + } + } + config.Plugins[i] = pluginConfig + } + } + } + + // If plugins are not present in the store, we will use the config file + if len(config.Plugins) == 0 && len(configData.Plugins) > 0 { + logger.Debug("no plugins found in store, processing from config file") + config.Plugins = configData.Plugins + + for i, plugin := range config.Plugins { + if plugin.Name == semanticcache.PluginName { + if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil { + logger.Warn("failed to add provider keys to semantic cache config: %v", err) + } + config.Plugins[i] = plugin + } + } + + if config.ConfigStore != nil { + logger.Debug("updating plugins in store") + for _, plugin := range config.Plugins { + pluginConfigCopy, err := DeepCopy(plugin.Config) + if err != nil { + logger.Warn("failed to deep copy plugin config, skipping database update: %v", err) + continue + } + + pluginConfig := &configstore.TablePlugin{ + Name: plugin.Name, + Enabled: plugin.Enabled, + Config: pluginConfigCopy, + } + if plugin.Name == semanticcache.PluginName { + if err := config.RemoveProviderKeysFromSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to remove provider keys from semantic cache config: %v", err) + } + } + if err := config.ConfigStore.CreatePlugin(pluginConfig); err != nil { + logger.Warn("failed to update plugin: %v", err) + } + } + } + } + + // 6. Check for Env Keys in config store + + // Initialize env keys + if config.ConfigStore != nil { + envKeys, err := config.ConfigStore.GetEnvKeys() + if err != nil { + logger.Warn("failed to get env keys from store: %v", err) + } + config.EnvKeys = envKeys + } + + if config.EnvKeys == nil { + config.EnvKeys = make(map[string][]configstore.EnvKeyInfo) + } + + logger.Info("successfully loaded configuration") + return config, nil +} + +// GetRawConfigString returns the raw configuration string. +func (s *Config) GetRawConfigString() string { + data, err := os.ReadFile(s.configPath) + if err != nil { + return "{}" + } + return string(data) +} + +// processEnvValue checks and replaces environment variable references in configuration values. +// Returns the processed value and the environment variable name if it was an env reference. +// Supports the "env.VARIABLE_NAME" syntax for referencing environment variables. +// This enables secure configuration management without hardcoding sensitive values. +// +// Examples: +// - "env.OPENAI_API_KEY" -> actual value from OPENAI_API_KEY environment variable +// - "sk-1234567890" -> returned as-is (no env prefix) +func (s *Config) processEnvValue(value string) (string, string, error) { + v := strings.TrimSpace(value) + if !strings.HasPrefix(v, "env.") { + return value, "", nil // do not trim non-env values + } + envKey := strings.TrimSpace(strings.TrimPrefix(v, "env.")) + if envKey == "" { + return "", "", fmt.Errorf("environment variable name missing in %q", value) + } + if envValue, ok := os.LookupEnv(envKey); ok { + return envValue, envKey, nil + } + return "", envKey, fmt.Errorf("environment variable %s not found", envKey) +} + +// getRestoredMCPConfig creates a copy of MCP config with env variable references restored +func (s *Config) getRestoredMCPConfig(envVarsByPath map[string]string) *schemas.MCPConfig { + if s.MCPConfig == nil { + return nil + } + + // Create a copy of the MCP config + mcpConfigCopy := &schemas.MCPConfig{ + ClientConfigs: make([]schemas.MCPClientConfig, len(s.MCPConfig.ClientConfigs)), + } + + // Process each client config + for i, clientConfig := range s.MCPConfig.ClientConfigs { + configCopy := schemas.MCPClientConfig{ + Name: clientConfig.Name, + ConnectionType: clientConfig.ConnectionType, + StdioConfig: clientConfig.StdioConfig, + ToolsToExecute: append([]string{}, clientConfig.ToolsToExecute...), + ToolsToSkip: append([]string{}, clientConfig.ToolsToSkip...), + } + + // Handle connection string with env variable restoration + if clientConfig.ConnectionString != nil { + connStr := *clientConfig.ConnectionString + path := fmt.Sprintf("mcp.client_configs[%d].connection_string", i) + if envVar, ok := envVarsByPath[path]; ok { + connStr = "env." + envVar + } + // If not from env var, keep actual value (no asterisk redaction) + configCopy.ConnectionString = &connStr + } + + mcpConfigCopy.ClientConfigs[i] = configCopy + } + + return mcpConfigCopy +} + +// GetProviderConfigRaw retrieves the raw, unredacted provider configuration from memory. +// This method is for internal use only, particularly by the account implementation. +// +// Performance characteristics: +// - Memory access: ultra-fast direct memory access +// - No database I/O or JSON parsing overhead +// - Thread-safe with read locks for concurrent access +// +// Returns a copy of the configuration to prevent external modifications. +func (s *Config) GetProviderConfigRaw(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + config, exists := s.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + // Return direct reference for maximum performance - this is used by Bifrost core + // CRITICAL: Never modify the returned data as it's shared + return &config, nil +} + +// HandlerStore interface implementation + +// ShouldAllowDirectKeys returns whether direct API keys in headers are allowed +// Note: This method doesn't use locking for performance. In rare cases during +// config updates, it may return stale data, but this is acceptable since bool +// reads are atomic and won't cause panics. +func (s *Config) ShouldAllowDirectKeys() bool { + return s.ClientConfig.AllowDirectKeys +} + +// GetProviderConfigRedacted retrieves a provider configuration with sensitive values redacted. +// This method is intended for external API responses and logging. +// +// The returned configuration has sensitive values redacted: +// - API keys are redacted using RedactKey() +// - Values from environment variables show the original env var name (env.VAR_NAME) +// +// Returns a new copy with redacted values that is safe to expose externally. +func (s *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + config, exists := s.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + // Create a map for quick lookup of env vars for this provider + envVarsByPath := make(map[string]string) + for envVar, infos := range s.EnvKeys { + for _, info := range infos { + if info.Provider == provider { + envVarsByPath[info.ConfigPath] = envVar + } + } + } + + // Create redacted config with same structure but redacted values + redactedConfig := configstore.ProviderConfig{ + NetworkConfig: config.NetworkConfig, + ConcurrencyAndBufferSize: config.ConcurrencyAndBufferSize, + ProxyConfig: config.ProxyConfig, + SendBackRawResponse: config.SendBackRawResponse, + CustomProviderConfig: config.CustomProviderConfig, + } + + // Create redacted keys + redactedConfig.Keys = make([]schemas.Key, len(config.Keys)) + for i, key := range config.Keys { + redactedConfig.Keys[i] = schemas.Key{ + ID: key.ID, + Models: key.Models, // Copy slice reference - read-only so safe + Weight: key.Weight, + } + + // Redact API key value + path := fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + redactedConfig.Keys[i].Value = "env." + envVar + } else if !strings.HasPrefix(key.Value, "env.") { + redactedConfig.Keys[i].Value = RedactKey(key.Value) + } + + // Redact Azure key config if present + if key.AzureKeyConfig != nil { + azureConfig := &schemas.AzureKeyConfig{ + Deployments: key.AzureKeyConfig.Deployments, + } + + // Redact Endpoint + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.endpoint", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.Endpoint = "env." + envVar + } else if !strings.HasPrefix(key.AzureKeyConfig.Endpoint, "env.") { + azureConfig.Endpoint = RedactKey(key.AzureKeyConfig.Endpoint) + } + + // Redact APIVersion if present + if key.AzureKeyConfig.APIVersion != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.api_version", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.APIVersion = bifrost.Ptr("env." + envVar) + } else { + // APIVersion is not sensitive, keep as-is + azureConfig.APIVersion = key.AzureKeyConfig.APIVersion + } + } + + redactedConfig.Keys[i].AzureKeyConfig = azureConfig + } + + // Redact Vertex key config if present + if key.VertexKeyConfig != nil { + vertexConfig := &schemas.VertexKeyConfig{} + + // Redact ProjectID + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.project_id", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.ProjectID = "env." + envVar + } else if !strings.HasPrefix(key.VertexKeyConfig.ProjectID, "env.") { + vertexConfig.ProjectID = RedactKey(key.VertexKeyConfig.ProjectID) + } + + // Region is not sensitive, handle env vars only + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.region", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.Region = "env." + envVar + } else { + vertexConfig.Region = key.VertexKeyConfig.Region + } + + // Redact AuthCredentials + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.auth_credentials", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.AuthCredentials = "env." + envVar + } else if !strings.HasPrefix(key.VertexKeyConfig.AuthCredentials, "env.") { + vertexConfig.AuthCredentials = RedactKey(key.VertexKeyConfig.AuthCredentials) + } + + redactedConfig.Keys[i].VertexKeyConfig = vertexConfig + } + + // Redact Bedrock key config if present + if key.BedrockKeyConfig != nil { + bedrockConfig := &schemas.BedrockKeyConfig{ + Deployments: key.BedrockKeyConfig.Deployments, + } + + // Redact AccessKey + path = fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.access_key", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + bedrockConfig.AccessKey = "env." + envVar + } else if !strings.HasPrefix(key.BedrockKeyConfig.AccessKey, "env.") { + bedrockConfig.AccessKey = RedactKey(key.BedrockKeyConfig.AccessKey) + } + + // Redact SecretKey + path = fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.secret_key", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + bedrockConfig.SecretKey = "env." + envVar + } else if !strings.HasPrefix(key.BedrockKeyConfig.SecretKey, "env.") { + bedrockConfig.SecretKey = RedactKey(key.BedrockKeyConfig.SecretKey) + } + + // Redact SessionToken + path = fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.session_token", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + bedrockConfig.SessionToken = bifrost.Ptr("env." + envVar) + } else { + bedrockConfig.SessionToken = key.BedrockKeyConfig.SessionToken + } + + // Redact Region + path = fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.region", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + bedrockConfig.Region = bifrost.Ptr("env." + envVar) + } else { + bedrockConfig.Region = key.BedrockKeyConfig.Region + } + + // Redact ARN + path = fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.arn", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + bedrockConfig.ARN = bifrost.Ptr("env." + envVar) + } else { + bedrockConfig.ARN = key.BedrockKeyConfig.ARN + } + + redactedConfig.Keys[i].BedrockKeyConfig = bedrockConfig + } + } + + return &redactedConfig, nil +} + +// GetAllProviders returns all configured provider names. +func (s *Config) GetAllProviders() ([]schemas.ModelProvider, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + providers := make([]schemas.ModelProvider, 0, len(s.Providers)) + for provider := range s.Providers { + providers = append(providers, provider) + } + + return providers, nil +} + +// AddProvider adds a new provider configuration to memory with full environment variable +// processing. This method is called when new providers are added via the HTTP API. +// +// The method: +// - Validates that the provider doesn't already exist +// - Processes environment variables in API keys, and key-level configs +// - Stores the processed configuration in memory +// - Updates metadata and timestamps +func (s *Config) AddProvider(provider schemas.ModelProvider, config configstore.ProviderConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if provider already exists + if _, exists := s.Providers[provider]; exists { + return fmt.Errorf("provider %s already exists", provider) + } + + // Validate CustomProviderConfig if present + if err := ValidateCustomProvider(config, provider); err != nil { + return err + } + newEnvKeys := make(map[string]struct{}) + + // Process environment variables in keys (including key-level configs) + for i, key := range config.Keys { + if key.ID == "" { + config.Keys[i].ID = uuid.NewString() + } + + // Process API key value + processedValue, envVar, err := s.processEnvValue(key.Value) + if err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process env var in key: %w", err) + } + config.Keys[i].Value = processedValue + + // Track environment key if it came from env + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "api_key", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, config.Keys[i].ID), + KeyID: config.Keys[i].ID, + }) + } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Azure key config env vars: %w", err) + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Vertex key config env vars: %w", err) + } + } + + // Process Bedrock key config if present + if key.BedrockKeyConfig != nil { + if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Bedrock key config env vars: %w", err) + } + } + } + + s.Providers[provider] = config + + if s.ConfigStore != nil { + if err := s.ConfigStore.AddProvider(provider, config, s.EnvKeys); err != nil { + return fmt.Errorf("failed to update provider config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + logger.Info("added provider: %s", provider) + return nil +} + +// UpdateProviderConfig updates a provider configuration in memory with full environment +// variable processing. This method is called when provider configurations are modified +// via the HTTP API and ensures all data processing is done upfront. +// +// The method: +// - Processes environment variables in API keys, and key-level configs +// - Stores the processed configuration in memory +// - Updates metadata and timestamps +// - Thread-safe operation with write locks +// +// Note: Environment variable cleanup for deleted/updated keys is now handled automatically +// by the mergeKeys function before this method is called. +// +// Parameters: +// - provider: The provider to update +// - config: The new configuration +func (s *Config) UpdateProviderConfig(provider schemas.ModelProvider, config configstore.ProviderConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Get existing configuration for validation + existingConfig, exists := s.Providers[provider] + if !exists { + return fmt.Errorf("provider %s not found", provider) + } + + // Validate CustomProviderConfig if present, ensuring immutable fields are not changed + if err := ValidateCustomProviderUpdate(config, existingConfig, provider); err != nil { + return err + } + // Track new environment variables being added + newEnvKeys := make(map[string]struct{}) + + // Process environment variables in keys (including key-level configs) + for i, key := range config.Keys { + if key.ID == "" { + config.Keys[i].ID = uuid.NewString() + } + + // Process API key value + processedValue, envVar, err := s.processEnvValue(key.Value) + if err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) // Clean up only new vars on failure + return fmt.Errorf("failed to process env var in key: %w", err) + } + config.Keys[i].Value = processedValue + + // Track environment key if it came from env + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "api_key", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, config.Keys[i].ID), + KeyID: config.Keys[i].ID, + }) + } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Azure key config env vars: %w", err) + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Vertex key config env vars: %w", err) + } + } + + // Process Bedrock key config if present + if key.BedrockKeyConfig != nil { + if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(provider, "", newEnvKeys) + return fmt.Errorf("failed to process Bedrock key config env vars: %w", err) + } + } + } + + s.Providers[provider] = config + + if s.ConfigStore != nil { + if err := s.ConfigStore.UpdateProvider(provider, config, s.EnvKeys); err != nil { + return fmt.Errorf("failed to update provider config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + logger.Info("Updated configuration for provider: %s", provider) + return nil +} + +// RemoveProvider removes a provider configuration from memory. +func (s *Config) RemoveProvider(provider schemas.ModelProvider) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.Providers[provider]; !exists { + return fmt.Errorf("provider %s not found", provider) + } + + delete(s.Providers, provider) + s.cleanupEnvKeys(provider, "", nil) + + if s.ConfigStore != nil { + if err := s.ConfigStore.DeleteProvider(provider); err != nil { + return fmt.Errorf("failed to update provider config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + logger.Info("Removed provider: %s", provider) + return nil +} + +// GetAllKeys returns the redacted keys +func (s *Config) GetAllKeys() ([]configstore.TableKey, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + keys := make([]configstore.TableKey, 0) + for providerKey, provider := range s.Providers { + for _, key := range provider.Keys { + keys = append(keys, configstore.TableKey{ + KeyID: key.ID, + Value: "", + Models: key.Models, + Weight: key.Weight, + Provider: string(providerKey), + }) + } + } + + return keys, nil +} + +// processMCPEnvVars processes environment variables in the MCP configuration. +// This method handles the MCP config structures and processes environment +// variables in their fields, ensuring type safety and proper field handling. +// +// Supported fields that are processed: +// - ConnectionString in each MCP ClientConfig +// +// Returns an error if any required environment variable is missing. +// This approach ensures type safety while supporting environment variable substitution. +func (s *Config) processMCPEnvVars() error { + if s.MCPConfig == nil { + return nil + } + + var missingEnvVars []string + + // Process each client config + for i, clientConfig := range s.MCPConfig.ClientConfigs { + // Process ConnectionString if present + if clientConfig.ConnectionString != nil { + newValue, envVar, err := s.processEnvValue(*clientConfig.ConnectionString) + if err != nil { + logger.Warn("failed to process env vars in MCP client %s: %v", clientConfig.Name, err) + missingEnvVars = append(missingEnvVars, envVar) + continue + } + if envVar != "" { + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: "", + KeyType: "connection_string", + ConfigPath: fmt.Sprintf("mcp.client_configs.%s.connection_string", clientConfig.Name), + KeyID: "", // Empty for MCP connection strings + }) + } + s.MCPConfig.ClientConfigs[i].ConnectionString = &newValue + } + } + + if len(missingEnvVars) > 0 { + return fmt.Errorf("missing environment variables: %v", missingEnvVars) + } + + return nil +} + +// SetBifrostClient sets the Bifrost client in the store. +// This is used to allow the store to access the Bifrost client. +// This is useful for the MCP handler to access the Bifrost client. +func (s *Config) SetBifrostClient(client *bifrost.Bifrost) { + s.muMCP.Lock() + defer s.muMCP.Unlock() + + s.client = client +} + +// AddMCPClient adds a new MCP client to the configuration. +// This method is called when a new MCP client is added via the HTTP API. +// +// The method: +// - Validates that the MCP client doesn't already exist +// - Processes environment variables in the MCP client configuration +// - Stores the processed configuration in memory +func (s *Config) AddMCPClient(clientConfig schemas.MCPClientConfig) error { + if s.client == nil { + return fmt.Errorf("bifrost client not set") + } + + s.muMCP.Lock() + defer s.muMCP.Unlock() + + if s.MCPConfig == nil { + s.MCPConfig = &schemas.MCPConfig{} + } + + // Track new environment variables + newEnvKeys := make(map[string]struct{}) + + s.MCPConfig.ClientConfigs = append(s.MCPConfig.ClientConfigs, clientConfig) + + // Process environment variables in the new client config + if clientConfig.ConnectionString != nil { + processedValue, envVar, err := s.processEnvValue(*clientConfig.ConnectionString) + if err != nil { + s.MCPConfig.ClientConfigs = s.MCPConfig.ClientConfigs[:len(s.MCPConfig.ClientConfigs)-1] + return fmt.Errorf("failed to process env var in connection string: %w", err) + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: "", + KeyType: "connection_string", + ConfigPath: fmt.Sprintf("mcp.client_configs.%s.connection_string", clientConfig.Name), + KeyID: "", // Empty for MCP connection strings + }) + } + s.MCPConfig.ClientConfigs[len(s.MCPConfig.ClientConfigs)-1].ConnectionString = &processedValue + } + + // Config with processed env vars + if err := s.client.AddMCPClient(s.MCPConfig.ClientConfigs[len(s.MCPConfig.ClientConfigs)-1]); err != nil { + s.MCPConfig.ClientConfigs = s.MCPConfig.ClientConfigs[:len(s.MCPConfig.ClientConfigs)-1] + s.cleanupEnvKeys("", clientConfig.Name, newEnvKeys) + return fmt.Errorf("failed to add MCP client: %w", err) + } + + if s.ConfigStore != nil { + if err := s.ConfigStore.UpdateMCPConfig(s.MCPConfig, s.EnvKeys); err != nil { + return fmt.Errorf("failed to update MCP config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + return nil +} + +// RemoveMCPClient removes an MCP client from the configuration. +// This method is called when an MCP client is removed via the HTTP API. +// +// The method: +// - Validates that the MCP client exists +// - Removes the MCP client from the configuration +// - Removes the MCP client from the Bifrost client +func (s *Config) RemoveMCPClient(name string) error { + if s.client == nil { + return fmt.Errorf("bifrost client not set") + } + + s.muMCP.Lock() + defer s.muMCP.Unlock() + + if s.MCPConfig == nil { + return fmt.Errorf("no MCP config found") + } + + if err := s.client.RemoveMCPClient(name); err != nil { + return fmt.Errorf("failed to remove MCP client: %w", err) + } + + for i, clientConfig := range s.MCPConfig.ClientConfigs { + if clientConfig.Name == name { + s.MCPConfig.ClientConfigs = append(s.MCPConfig.ClientConfigs[:i], s.MCPConfig.ClientConfigs[i+1:]...) + break + } + } + + s.cleanupEnvKeys("", name, nil) + + if s.ConfigStore != nil { + if err := s.ConfigStore.UpdateMCPConfig(s.MCPConfig, s.EnvKeys); err != nil { + return fmt.Errorf("failed to update MCP config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + return nil +} + +// EditMCPClientTools edits the tools of an MCP client. +// This allows for dynamic MCP client tool management at runtime. +// +// Parameters: +// - name: Name of the client to edit +// - toolsToAdd: Tools to add to the client +// - toolsToRemove: Tools to remove from the client +func (s *Config) EditMCPClientTools(name string, toolsToAdd []string, toolsToRemove []string) error { + if s.client == nil { + return fmt.Errorf("bifrost client not set") + } + + s.muMCP.Lock() + defer s.muMCP.Unlock() + + if s.MCPConfig == nil { + return fmt.Errorf("no MCP config found") + } + + if err := s.client.EditMCPClientTools(name, toolsToAdd, toolsToRemove); err != nil { + return fmt.Errorf("failed to edit MCP client tools: %w", err) + } + + for i, clientConfig := range s.MCPConfig.ClientConfigs { + if clientConfig.Name == name { + s.MCPConfig.ClientConfigs[i].ToolsToExecute = toolsToAdd + s.MCPConfig.ClientConfigs[i].ToolsToSkip = toolsToRemove + break + } + } + + if s.ConfigStore != nil { + if err := s.ConfigStore.UpdateMCPConfig(s.MCPConfig, s.EnvKeys); err != nil { + return fmt.Errorf("failed to update MCP config in store: %w", err) + } + if err := s.ConfigStore.UpdateEnvKeys(s.EnvKeys); err != nil { + logger.Warn("failed to update env keys: %v", err) + } + } + + return nil +} + +// RedactMCPClientConfig creates a redacted copy of an MCP client configuration. +// Connection strings are either redacted or replaced with their environment variable names. +func (s *Config) RedactMCPClientConfig(config schemas.MCPClientConfig) schemas.MCPClientConfig { + // Create a copy with basic fields + configCopy := schemas.MCPClientConfig{ + Name: config.Name, + ConnectionType: config.ConnectionType, + ConnectionString: config.ConnectionString, + StdioConfig: config.StdioConfig, + ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ToolsToSkip: append([]string{}, config.ToolsToSkip...), + } + + // Handle connection string if present + if config.ConnectionString != nil { + connStr := *config.ConnectionString + + // Check if this value came from an env var + for envVar, infos := range s.EnvKeys { + for _, info := range infos { + if info.Provider == "" && info.KeyType == "connection_string" && info.ConfigPath == fmt.Sprintf("mcp.client_configs.%s.connection_string", config.Name) { + connStr = "env." + envVar + break + } + } + } + + // If not from env var, redact it + if !strings.HasPrefix(connStr, "env.") { + connStr = RedactKey(connStr) + } + configCopy.ConnectionString = &connStr + } + + return configCopy +} + +// RedactKey redacts sensitive key values by showing only the first and last 4 characters +func RedactKey(key string) string { + if key == "" { + return "" + } + + // If key is 8 characters or less, just return all asterisks + if len(key) <= 8 { + return strings.Repeat("*", len(key)) + } + + // Show first 4 and last 4 characters, replace middle with asterisks + prefix := key[:4] + suffix := key[len(key)-4:] + middle := strings.Repeat("*", 24) + + return prefix + middle + suffix +} + +// IsRedacted checks if a key value is redacted, either by being an environment variable +// reference (env.VAR_NAME) or containing the exact redaction pattern from RedactKey. +func IsRedacted(key string) bool { + if key == "" { + return false + } + + // Check if it's an environment variable reference + if strings.HasPrefix(key, "env.") { + return true + } + + if len(key) <= 8 { + return strings.Count(key, "*") == len(key) + } + + // Check for exact redaction pattern: 4 chars + 24 asterisks + 4 chars + if len(key) == 32 { + middle := key[4:28] + if middle == strings.Repeat("*", 24) { + return true + } + } + + return false +} + +// cleanupEnvKeys removes environment variable entries from the store based on the given criteria. +// If envVarsToRemove is nil, it removes all env vars for the specified provider/client. +// If envVarsToRemove is provided, it only removes those specific env vars. +// +// Parameters: +// - provider: Provider name to clean up (empty string for MCP clients) +// - mcpClientName: MCP client name to clean up (empty string for providers) +// - envVarsToRemove: Optional map of specific env vars to remove (nil to remove all) +func (s *Config) cleanupEnvKeys(provider schemas.ModelProvider, mcpClientName string, envVarsToRemove map[string]struct{}) { + // If envVarsToRemove is provided, only clean those specific vars + if envVarsToRemove != nil { + for envVar := range envVarsToRemove { + s.cleanupEnvVar(envVar, provider, mcpClientName) + } + return + } + + // If envVarsToRemove is nil, clean all vars for the provider/client + for envVar := range s.EnvKeys { + s.cleanupEnvVar(envVar, provider, mcpClientName) + } +} + +// cleanupEnvVar removes entries for a specific environment variable based on provider/client. +// This is a helper function to avoid duplicating the filtering logic. +func (s *Config) cleanupEnvVar(envVar string, provider schemas.ModelProvider, mcpClientName string) { + infos := s.EnvKeys[envVar] + if len(infos) == 0 { + return + } + + // Keep entries that don't match the provider/client we're cleaning up + filteredInfos := make([]configstore.EnvKeyInfo, 0, len(infos)) + for _, info := range infos { + shouldKeep := false + if provider != "" { + shouldKeep = info.Provider != provider + } else if mcpClientName != "" { + shouldKeep = info.Provider != "" || !strings.HasPrefix(info.ConfigPath, fmt.Sprintf("mcp.client_configs.%s", mcpClientName)) + } + if shouldKeep { + filteredInfos = append(filteredInfos, info) + } + } + + if len(filteredInfos) == 0 { + delete(s.EnvKeys, envVar) + } else { + s.EnvKeys[envVar] = filteredInfos + } +} + +// CleanupEnvKeysForKeys removes environment variable entries for specific keys that are being deleted. +// This function targets key-specific environment variables based on key IDs. +// +// Parameters: +// - provider: Provider name the keys belong to +// - keysToDelete: List of keys being deleted (uses their IDs to identify env vars to clean up) +func (s *Config) CleanupEnvKeysForKeys(provider schemas.ModelProvider, keysToDelete []schemas.Key) { + // Create a set of key IDs to delete for efficient lookup + keyIDsToDelete := make(map[string]bool) + for _, key := range keysToDelete { + keyIDsToDelete[key.ID] = true + } + + // Iterate through all environment variables and remove entries for deleted keys + for envVar, infos := range s.EnvKeys { + filteredInfos := make([]configstore.EnvKeyInfo, 0, len(infos)) + + for _, info := range infos { + // Keep entries that either: + // 1. Don't belong to this provider, OR + // 2. Don't have a KeyID (MCP), OR + // 3. Have a KeyID that's not being deleted + shouldKeep := info.Provider != provider || + info.KeyID == "" || + !keyIDsToDelete[info.KeyID] + + if shouldKeep { + filteredInfos = append(filteredInfos, info) + } + } + + // Update or delete the environment variable entry + if len(filteredInfos) == 0 { + delete(s.EnvKeys, envVar) + } else { + s.EnvKeys[envVar] = filteredInfos + } + } +} + +// CleanupEnvKeysForUpdatedKeys removes environment variable entries for keys that are being updated +// but only for fields where the environment variable reference has actually changed. +// This function is called after the merge to compare final values with original values. +// +// Parameters: +// - provider: Provider name the keys belong to +// - keysToUpdate: List of keys being updated +// - oldKeys: List of original keys before update +// - mergedKeys: List of final merged keys after update +func (s *Config) CleanupEnvKeysForUpdatedKeys(provider schemas.ModelProvider, keysToUpdate []schemas.Key, oldKeys []schemas.Key, mergedKeys []schemas.Key) { + // Create maps for efficient lookup + keysToUpdateMap := make(map[string]schemas.Key) + for _, key := range keysToUpdate { + keysToUpdateMap[key.ID] = key + } + + oldKeysMap := make(map[string]schemas.Key) + for _, key := range oldKeys { + oldKeysMap[key.ID] = key + } + + mergedKeysMap := make(map[string]schemas.Key) + for _, key := range mergedKeys { + mergedKeysMap[key.ID] = key + } + + // Iterate through all environment variables and remove entries only for fields that are changing + for envVar, infos := range s.EnvKeys { + filteredInfos := make([]configstore.EnvKeyInfo, 0, len(infos)) + + for _, info := range infos { + // Keep entries that either: + // 1. Don't belong to this provider, OR + // 2. Don't have a KeyID (MCP), OR + // 3. Have a KeyID that's not being updated, OR + // 4. Have a KeyID that's being updated but the env var reference hasn't changed + shouldKeep := info.Provider != provider || + info.KeyID == "" || + keysToUpdateMap[info.KeyID].ID == "" || + !s.isEnvVarReferenceChanging(mergedKeysMap[info.KeyID], oldKeysMap[info.KeyID], info.ConfigPath) + + if shouldKeep { + filteredInfos = append(filteredInfos, info) + } + } + + // Update or delete the environment variable entry + if len(filteredInfos) == 0 { + delete(s.EnvKeys, envVar) + } else { + s.EnvKeys[envVar] = filteredInfos + } + } +} + +// isEnvVarReferenceChanging checks if an environment variable reference is changing between old and merged key +func (s *Config) isEnvVarReferenceChanging(mergedKey, oldKey schemas.Key, configPath string) bool { + // Extract the field name from the config path + // e.g., "providers.vertex.keys[123].vertex_key_config.project_id" -> "project_id" + pathParts := strings.Split(configPath, ".") + if len(pathParts) < 2 { + return false + } + fieldName := pathParts[len(pathParts)-1] + + // Get the old and merged values for this field + oldValue := s.getFieldValue(oldKey, fieldName) + mergedValue := s.getFieldValue(mergedKey, fieldName) + + // If either value is an env var reference, check if they're different + oldIsEnvVar := strings.HasPrefix(oldValue, "env.") + mergedIsEnvVar := strings.HasPrefix(mergedValue, "env.") + + // If both are env vars, check if they reference the same variable + if oldIsEnvVar && mergedIsEnvVar { + return oldValue != mergedValue + } + + // If one is env var and other isn't, or both are different types, it's changing + return oldIsEnvVar != mergedIsEnvVar || oldValue != mergedValue +} + +// getFieldValue extracts the value of a specific field from a key based on the field name +func (s *Config) getFieldValue(key schemas.Key, fieldName string) string { + switch fieldName { + case "project_id": + if key.VertexKeyConfig != nil { + return key.VertexKeyConfig.ProjectID + } + case "region": + if key.VertexKeyConfig != nil { + return key.VertexKeyConfig.Region + } + case "auth_credentials": + if key.VertexKeyConfig != nil { + return key.VertexKeyConfig.AuthCredentials + } + case "endpoint": + if key.AzureKeyConfig != nil { + return key.AzureKeyConfig.Endpoint + } + case "api_version": + if key.AzureKeyConfig != nil && key.AzureKeyConfig.APIVersion != nil { + return *key.AzureKeyConfig.APIVersion + } + case "access_key": + if key.BedrockKeyConfig != nil { + return key.BedrockKeyConfig.AccessKey + } + case "secret_key": + if key.BedrockKeyConfig != nil { + return key.BedrockKeyConfig.SecretKey + } + case "session_token": + if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.SessionToken != nil { + return *key.BedrockKeyConfig.SessionToken + } + default: + // For the main API key value + if fieldName == "value" || strings.Contains(fieldName, "key") { + return key.Value + } + } + return "" +} + +// autoDetectProviders automatically detects common environment variables and sets up providers +// when no configuration file exists. This enables zero-config startup when users have set +// standard environment variables like OPENAI_API_KEY, ANTHROPIC_API_KEY, etc. +// +// Supported environment variables: +// - OpenAI: OPENAI_API_KEY, OPENAI_KEY +// - Anthropic: ANTHROPIC_API_KEY, ANTHROPIC_KEY +// - Mistral: MISTRAL_API_KEY, MISTRAL_KEY +// +// For each detected provider, it creates a default configuration with: +// - The detected API key with weight 1.0 +// - Empty models list (provider will use default models) +// - Default concurrency and buffer size settings +func (s *Config) autoDetectProviders() { + // Define common environment variable patterns for each provider + providerEnvVars := map[schemas.ModelProvider][]string{ + schemas.OpenAI: {"OPENAI_API_KEY", "OPENAI_KEY"}, + schemas.Anthropic: {"ANTHROPIC_API_KEY", "ANTHROPIC_KEY"}, + schemas.Mistral: {"MISTRAL_API_KEY", "MISTRAL_KEY"}, + } + + detectedCount := 0 + + for provider, envVars := range providerEnvVars { + for _, envVar := range envVars { + if apiKey := os.Getenv(envVar); apiKey != "" { + // Generate a unique ID for the auto-detected key + keyID := uuid.NewString() + + // Create default provider configuration + providerConfig := configstore.ProviderConfig{ + Keys: []schemas.Key{ + { + ID: keyID, + Value: apiKey, + Models: []string{}, // Empty means all supported models + Weight: 1.0, + }, + }, + ConcurrencyAndBufferSize: &schemas.DefaultConcurrencyAndBufferSize, + } + + // Add to providers map + s.Providers[provider] = providerConfig + + // Track the environment variable + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "api_key", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, keyID), + KeyID: keyID, + }) + + logger.Info("auto-detected %s provider from environment variable %s", provider, envVar) + detectedCount++ + break // Only use the first found env var for each provider + } + } + } + + if detectedCount > 0 { + logger.Info("auto-configured %d provider(s) from environment variables", detectedCount) + if s.ConfigStore != nil { + if err := s.ConfigStore.UpdateProvidersConfig(s.Providers); err != nil { + logger.Error("failed to update providers in store: %v", err) + } + } + } +} + +// processAzureKeyConfigEnvVars processes environment variables in Azure key configuration +func (s *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { + azureConfig := key.AzureKeyConfig + + // Process Endpoint + processedEndpoint, envVar, err := s.processEnvValue(azureConfig.Endpoint) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.endpoint", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.Endpoint = processedEndpoint + + // Process APIVersion if present + if azureConfig.APIVersion != nil { + processedAPIVersion, envVar, err := s.processEnvValue(*azureConfig.APIVersion) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.api_version", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.APIVersion = &processedAPIVersion + } + + return nil +} + +// processVertexKeyConfigEnvVars processes environment variables in Vertex key configuration +func (s *Config) processVertexKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { + vertexConfig := key.VertexKeyConfig + + // Process ProjectID + processedProjectID, envVar, err := s.processEnvValue(vertexConfig.ProjectID) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.project_id", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.ProjectID = processedProjectID + + // Process Region + processedRegion, envVar, err := s.processEnvValue(vertexConfig.Region) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.region", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.Region = processedRegion + + // Process AuthCredentials + processedAuthCredentials, envVar, err := s.processEnvValue(vertexConfig.AuthCredentials) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.auth_credentials", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.AuthCredentials = processedAuthCredentials + + return nil +} + +// processBedrockKeyConfigEnvVars processes environment variables in Bedrock key configuration +func (s *Config) processBedrockKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { + bedrockConfig := key.BedrockKeyConfig + + // Process AccessKey + processedAccessKey, envVar, err := s.processEnvValue(bedrockConfig.AccessKey) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "bedrock_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.access_key", provider, key.ID), + KeyID: key.ID, + }) + } + bedrockConfig.AccessKey = processedAccessKey + + // Process SecretKey + processedSecretKey, envVar, err := s.processEnvValue(bedrockConfig.SecretKey) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "bedrock_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.secret_key", provider, key.ID), + KeyID: key.ID, + }) + } + bedrockConfig.SecretKey = processedSecretKey + + // Process SessionToken if present + if bedrockConfig.SessionToken != nil { + processedSessionToken, envVar, err := s.processEnvValue(*bedrockConfig.SessionToken) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "bedrock_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.session_token", provider, key.ID), + KeyID: key.ID, + }) + } + bedrockConfig.SessionToken = &processedSessionToken + } + + // Process Region if present + if bedrockConfig.Region != nil { + processedRegion, envVar, err := s.processEnvValue(*bedrockConfig.Region) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "bedrock_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.region", provider, key.ID), + KeyID: key.ID, + }) + } + bedrockConfig.Region = &processedRegion + } + + // Process ARN if present + if bedrockConfig.ARN != nil { + processedARN, envVar, err := s.processEnvValue(*bedrockConfig.ARN) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "bedrock_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].bedrock_key_config.arn", provider, key.ID), + KeyID: key.ID, + }) + } + bedrockConfig.ARN = &processedARN + } + + return nil +} + +// GetVectorStoreConfigRedacted retrieves the vector store configuration with password redacted for safe external exposure +func (s *Config) GetVectorStoreConfigRedacted() (*vectorstore.Config, error) { + var err error + var vectorStoreConfig *vectorstore.Config + if s.ConfigStore != nil { + vectorStoreConfig, err = s.ConfigStore.GetVectorStoreConfig() + if err != nil { + return nil, fmt.Errorf("failed to get vector store config: %w", err) + } + } + if vectorStoreConfig == nil { + return nil, nil + } + if vectorStoreConfig.Type == vectorstore.VectorStoreTypeWeaviate { + weaviateConfig, ok := vectorStoreConfig.Config.(*vectorstore.WeaviateConfig) + if !ok { + return nil, fmt.Errorf("failed to cast vector store config to weaviate config") + } + // Create a copy to avoid modifying the original + redactedWeaviateConfig := *weaviateConfig + // Redact password if it exists + if redactedWeaviateConfig.ApiKey != "" { + redactedWeaviateConfig.ApiKey = RedactKey(redactedWeaviateConfig.ApiKey) + } + redactedVectorStoreConfig := *vectorStoreConfig + redactedVectorStoreConfig.Config = &redactedWeaviateConfig + return &redactedVectorStoreConfig, nil + } + return nil, nil +} + +// ValidateCustomProvider validates the custom provider configuration +func ValidateCustomProvider(config configstore.ProviderConfig, provider schemas.ModelProvider) error { + if config.CustomProviderConfig == nil { + return nil + } + + if bifrost.IsStandardProvider(provider) { + return fmt.Errorf("custom provider validation failed: cannot be created on standard providers: %s", provider) + } + + cpc := config.CustomProviderConfig + + // Validate base provider type + if cpc.BaseProviderType == "" { + return fmt.Errorf("custom provider validation failed: base_provider_type is required") + } + + // Check if base provider is a supported base provider + if !bifrost.IsSupportedBaseProvider(cpc.BaseProviderType) { + return fmt.Errorf("custom provider validation failed: unsupported base_provider_type: %s", cpc.BaseProviderType) + } + return nil +} + +// ValidateCustomProviderUpdate validates that immutable fields in CustomProviderConfig are not changed during updates +func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.ProviderConfig, provider schemas.ModelProvider) error { + // If neither config has CustomProviderConfig, no validation needed + if newConfig.CustomProviderConfig == nil && existingConfig.CustomProviderConfig == nil { + return nil + } + + // If new config doesn't have CustomProviderConfig but existing does, return an error + if newConfig.CustomProviderConfig == nil { + return fmt.Errorf("custom_provider_config cannot be removed after creation for provider %s", provider) + } + + // If existing config doesn't have CustomProviderConfig but new one does, that's fine (adding it) + if existingConfig.CustomProviderConfig == nil { + return ValidateCustomProvider(newConfig, provider) + } + + // Both configs have CustomProviderConfig, validate immutable fields + newCPC := newConfig.CustomProviderConfig + existingCPC := existingConfig.CustomProviderConfig + + // CustomProviderKey is internally set and immutable, no validation needed + + // Check if BaseProviderType is being changed + if newCPC.BaseProviderType != existingCPC.BaseProviderType { + return fmt.Errorf("provider %s: base_provider_type cannot be changed from %s to %s after creation", + provider, existingCPC.BaseProviderType, newCPC.BaseProviderType) + } + + return nil +} + +func (s *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error { + if config.Name != semanticcache.PluginName { + return nil + } + + // Check if config.Config exists + if config.Config == nil { + return fmt.Errorf("semantic_cache plugin config is nil") + } + + // Type assert config.Config to map[string]interface{} + configMap, ok := config.Config.(map[string]interface{}) + if !ok { + return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config) + } + + // Check if provider key exists and is a string + providerVal, exists := configMap["provider"] + if !exists { + return fmt.Errorf("semantic_cache plugin missing required 'provider' field") + } + + provider, ok := providerVal.(string) + if !ok { + return fmt.Errorf("semantic_cache plugin 'provider' field must be a string, got %T", providerVal) + } + + if provider == "" { + return fmt.Errorf("semantic_cache plugin 'provider' field cannot be empty") + } + + keys, err := s.GetProviderConfigRaw(schemas.ModelProvider(provider)) + if err != nil { + return fmt.Errorf("failed to get provider config for %s: %w", provider, err) + } + + configMap["keys"] = keys.Keys + + return nil +} + +func (s *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstore.TablePlugin) error { + if config.Name != semanticcache.PluginName { + return nil + } + + // Check if config.Config exists + if config.Config == nil { + return fmt.Errorf("semantic_cache plugin config is nil") + } + + // Type assert config.Config to map[string]interface{} + configMap, ok := config.Config.(map[string]interface{}) + if !ok { + return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config) + } + + configMap["keys"] = []schemas.Key{} + + config.Config = configMap + + return nil +} + +func DeepCopy[T any](in T) (T, error) { + var out T + b, err := json.Marshal(in) + if err != nil { + return out, err + } + err = json.Unmarshal(b, &out) + return out, err +} diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go new file mode 100644 index 000000000..4a908e81e --- /dev/null +++ b/transports/bifrost-http/lib/ctx.go @@ -0,0 +1,243 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +// +// This package handles the conversion of FastHTTP request contexts to Bifrost contexts, +// ensuring that important metadata and tracking information is preserved across the system. +// It supports propagation of both Prometheus metrics and Maxim tracing data through HTTP headers. +package lib + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/plugins/maxim" + "github.com/maximhq/bifrost/plugins/semanticcache" + "github.com/maximhq/bifrost/plugins/telemetry" + "github.com/valyala/fasthttp" +) + +// ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context, +// preserving important header values for monitoring and tracing purposes. +// +// The function processes several types of special headers: +// 1. Prometheus Headers (x-bf-prom-*): +// - All headers prefixed with 'x-bf-prom-' are copied to the context +// - The prefix is stripped and the remainder becomes the context key +// - Example: 'x-bf-prom-latency' becomes 'latency' in the context +// +// 2. Maxim Tracing Headers (x-bf-maxim-*): +// - Specifically handles 'x-bf-maxim-traceID' and 'x-bf-maxim-generationID' +// - These headers enable trace correlation across service boundaries +// - Values are stored using Maxim's context keys for consistency +// +// 3. MCP Headers (x-bf-mcp-*): +// - Specifically handles 'x-bf-mcp-include-clients', 'x-bf-mcp-exclude-clients', 'x-bf-mcp-include-tools', and 'x-bf-mcp-exclude-tools' +// - These headers enable MCP client and tool filtering +// - Values are stored using MCP context keys for consistency +// +// 4. Governance Headers: +// - x-bf-vk: Virtual key for governance (required for governance to work) +// - x-bf-team: Team identifier for team-based governance rules +// - x-bf-user: User identifier for user-based governance rules +// - x-bf-customer: Customer identifier for customer-based governance rules +// +// 5. API Key Headers: +// - Authorization: Bearer token format only (e.g., "Bearer sk-...") - OpenAI style +// - x-api-key: Direct API key value - Anthropic style +// - Keys are extracted and stored in the context using schemas.BifrostContextKey +// - This enables explicit key usage for requests via headers +// + +// Parameters: +// - ctx: The FastHTTP request context containing the original headers +// +// Returns: +// - *context.Context: A new context.Context containing the propagated values +// +// Example Usage: +// +// fastCtx := &fasthttp.RequestCtx{...} +// bifrostCtx := ConvertToBifrostContext(fastCtx) +// // bifrostCtx now contains any prometheus and maxim header values + +type ContextKey string + +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool) *context.Context { + bifrostCtx := context.Background() + + // First, check if x-request-id header exists + requestID := string(ctx.Request.Header.Peek("x-request-id")) + if requestID == "" { + requestID = uuid.New().String() + } + bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey("request-id"), requestID) + + // Initialize tags map for collecting maxim tags + maximTags := make(map[string]string) + + // Then process other headers + ctx.Request.Header.All()(func(key, value []byte) bool { + keyStr := strings.ToLower(string(key)) + + if strings.HasPrefix(keyStr, "x-bf-prom-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-prom-") + bifrostCtx = context.WithValue(bifrostCtx, telemetry.ContextKey(labelName), string(value)) + } + + if strings.HasPrefix(keyStr, "x-bf-maxim-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-maxim-") + + if labelName == string(maxim.GenerationIDKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + if labelName == string(maxim.TraceIDKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + if labelName == string(maxim.SessionIDKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + if labelName == string(maxim.TraceNameKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + if labelName == string(maxim.GenerationNameKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + if labelName == string(maxim.LogRepoIDKey) { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(labelName), string(value)) + } + + // apart from these all headers starting with x-bf-maxim- are keys for tags + // collect them in the maximTags map + if labelName != string(maxim.GenerationIDKey) && labelName != string(maxim.TraceIDKey) && labelName != string(maxim.SessionIDKey) && labelName != string(maxim.TraceNameKey) && labelName != string(maxim.GenerationNameKey) && labelName != string(maxim.LogRepoIDKey) { + maximTags[labelName] = string(value) + } + } + + if strings.HasPrefix(keyStr, "x-bf-mcp-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-mcp-") + + if labelName == "include-clients" || labelName == "exclude-clients" || labelName == "include-tools" || labelName == "exclude-tools" { + bifrostCtx = context.WithValue(bifrostCtx, ContextKey("mcp-"+labelName), string(value)) + return true + } + } + + // Handle governance headers (x-bf-team, x-bf-user, x-bf-customer) + if keyStr == "x-bf-team" || keyStr == "x-bf-user" || keyStr == "x-bf-customer" { + bifrostCtx = context.WithValue(bifrostCtx, governance.ContextKey(keyStr), string(value)) + } + + // Handle virtual key header (x-bf-vk) + if keyStr == "x-bf-vk" { + bifrostCtx = context.WithValue(bifrostCtx, governance.ContextKey(keyStr), string(value)) + } + + // Handle cache key header (x-bf-cache-key) + if keyStr == "x-bf-cache-key" { + bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheKey, string(value)) + } + + // Handle cache TTL header (x-bf-cache-ttl) + if keyStr == "x-bf-cache-ttl" { + valueStr := string(value) + var ttlDuration time.Duration + var err error + + // First try to parse as duration (e.g., "30s", "5m", "1h") + if ttlDuration, err = time.ParseDuration(valueStr); err != nil { + // If that fails, try to parse as plain number and treat as seconds + if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 { + ttlDuration = time.Duration(seconds) * time.Second + err = nil // Reset error since we successfully parsed as seconds + } + } + + if err == nil { + bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheTTLKey, ttlDuration) + } + // If both parsing attempts fail, we silently ignore the header and use default TTL + } + + if keyStr == "x-bf-cache-threshold" { + threshold, err := strconv.ParseFloat(string(value), 64) + if err == nil { + // Clamp threshold to the inclusive range [0.0, 1.0] + if threshold < 0.0 { + threshold = 0.0 + } else if threshold > 1.0 { + threshold = 1.0 + } + bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheThresholdKey, threshold) + } + // If parsing fails, silently ignore the header (no context value set) + } + + if keyStr == "x-bf-cache-type" { + bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheTypeKey, semanticcache.CacheType(string(value))) + } + + if keyStr == "x-bf-cache-no-store" { + if valueStr := string(value); valueStr == "true" { + bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheNoStoreKey, true) + } + } + + return true + }) + + // Store the collected maxim tags in the context + if len(maximTags) > 0 { + bifrostCtx = context.WithValue(bifrostCtx, maxim.ContextKey(maxim.TagsKey), maximTags) + } + + if allowDirectKeys { + // Extract API key from Authorization header (Bearer format) or x-api-key header + var apiKey string + + // TODO: fix plugin data leak + // Check Authorization header (Bearer format only - OpenAI style) + authHeader := string(ctx.Request.Header.Peek("Authorization")) + if authHeader != "" { + // Only accept Bearer token format: "Bearer ..." + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix + if authHeaderValue != "" { + apiKey = authHeaderValue + } + } else { + apiKey = authHeader + } + } + + // Check x-api-key header if no valid Authorization header found (Anthropic style) + if apiKey == "" { + xAPIKey := string(ctx.Request.Header.Peek("x-api-key")) + if xAPIKey != "" { + apiKey = strings.TrimSpace(xAPIKey) + } + } + + // If we found an API key, create a Key object and store it in context + if apiKey != "" { + key := schemas.Key{ + ID: "header-provided", // Identifier for header-provided keys + Value: apiKey, + Models: []string{}, // Empty models list - will be validated by provider + Weight: 1.0, // Default weight + } + bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + } + } + + return &bifrostCtx +} diff --git a/transports/bifrost-http/lib/errors.go b/transports/bifrost-http/lib/errors.go new file mode 100644 index 000000000..e2e37d0b3 --- /dev/null +++ b/transports/bifrost-http/lib/errors.go @@ -0,0 +1,5 @@ +package lib + +import "errors" + +var ErrNotFound = errors.New("not found") diff --git a/transports/bifrost-http/lib/lib.go b/transports/bifrost-http/lib/lib.go new file mode 100644 index 000000000..230ad4b97 --- /dev/null +++ b/transports/bifrost-http/lib/lib.go @@ -0,0 +1,8 @@ +package lib + +import ( + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +var logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go new file mode 100644 index 000000000..a7947640e --- /dev/null +++ b/transports/bifrost-http/main.go @@ -0,0 +1,593 @@ +// Package http provides an HTTP service using FastHTTP that exposes endpoints +// for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, Mistral, Ollama, etc.). +// +// The HTTP service provides the following main endpoints: +// - /v1/text/completions: For text completion requests +// - /v1/chat/completions: For chat completion requests +// - /v1/mcp/tool/execute: For MCP tool execution requests +// - /providers/*: For provider configuration management +// +// Configuration is handled through a JSON config file, high-performance ConfigStore, and environment variables: +// - Use -app-dir flag to specify the application data directory (contains config.json and logs) +// - Use -port flag to specify the server port (default: 8080) +// - When no config file exists, common environment variables are auto-detected (OPENAI_API_KEY, ANTHROPIC_API_KEY, MISTRAL_API_KEY) +// +// ConfigStore Features: +// - Pure in-memory storage for ultra-fast config access +// - Environment variable processing for secure configuration management +// - Real-time configuration updates via HTTP API +// - Explicit persistence control via POST /config/save endpoint +// - Provider-specific key config support (Azure, Bedrock, Vertex) +// - Thread-safe operations with concurrent request handling +// - Statistics and monitoring endpoints for operational insights +// +// Performance Optimizations: +// - Configuration data is processed once during startup and stored in memory +// - Ultra-fast memory access eliminates I/O overhead on every request +// - All environment variable processing done upfront during configuration loading +// - Thread-safe concurrent access with read-write mutex protection +// +// Example usage: +// +// go run main.go -app-dir ./data -port 8080 -host 0.0.0.0 +// after setting provider API keys like OPENAI_API_KEY in the environment. +// +// To bind to all interfaces for container usage, set BIFROST_HOST=0.0.0.0 or use -host 0.0.0.0 +// +// Integration Support: +// Bifrost supports multiple AI provider integrations through dedicated HTTP endpoints. +// Each integration exposes API-compatible endpoints that accept the provider's native request format, +// automatically convert it to Bifrost's unified format, process it, and return the expected response format. +// +// Integration endpoints follow the pattern: /{provider}/{provider_api_path} +// Examples: +// - OpenAI: POST /openai/v1/chat/completions (accepts OpenAI ChatCompletion requests) +// - GenAI: POST /genai/v1beta/models/{model} (accepts Google GenAI requests) +// - Anthropic: POST /anthropic/v1/messages (accepts Anthropic Messages requests) +// +// This allows clients to use their existing integration code without modification while benefiting +// from Bifrost's unified model routing, fallbacks, monitoring capabilities, and high-performance configuration management. +// +// NOTE: Streaming is supported for chat completions via Server-Sent Events (SSE) +package main + +import ( + "context" + "embed" + "encoding/json" + "flag" + "fmt" + "mime" + "net" + "os" + "os/signal" + "path" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/pricing" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/plugins/logging" + "github.com/maximhq/bifrost/plugins/maxim" + "github.com/maximhq/bifrost/plugins/semanticcache" + "github.com/maximhq/bifrost/plugins/telemetry" + "github.com/maximhq/bifrost/transports/bifrost-http/handlers" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" +) + +//go:embed all:ui +var uiContent embed.FS + +var Version string + +var logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) + +// Command line flags +var ( + port string // Port to run the server on + host string // Host to bind the server to + appDir string // Application data directory + + logLevel string // Logger level: debug, info, warn, error + logOutputStyle string // Logger output style: json, pretty +) + +const ( + DefaultHost = "localhost" + DefaultPort = "8080" + DefaultAppDir = "./bifrost-data" + DefaultLogLevel = string(schemas.LogLevelInfo) + DefaultLogOutputStyle = string(schemas.LoggerOutputTypeJSON) +) + +// init initializes command line flags and validates required configuration. +// It sets up the following flags: +// - host: Host to bind the server to (default: localhost, can be overridden with BIFROST_HOST env var) +// - port: Server port (default: 8080) +// - app-dir: Application data directory (default: current directory) +// - log-level: Logger level (debug, info, warn, error). Default is info. +// - log-style: Logger output type (json or pretty). Default is JSON. + +func init() { + if Version == "" { + Version = "v1.0.0" + } + versionLine := fmt.Sprintf("β•‘%s%s%sβ•‘", strings.Repeat(" ", (61-2-len(Version))/2), Version, strings.Repeat(" ", (61-2-len(Version)+1)/2)) + // Welcome to bifrost! + fmt.Printf(` +╔═══════════════════════════════════════════════════════════╗ +β•‘ β•‘ +β•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β•‘ +β•‘ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β•β•β•β•β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•”β•β•β•β•β•β•šβ•β•β–ˆβ–ˆβ•”β•β•β• β•‘ +β•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•‘ β•‘ +β•‘ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β•β• β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ•β•β•β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β•‘ +β•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β•‘ +β•‘ β•šβ•β•β•β•β•β• β•šβ•β•β•šβ•β• β•šβ•β• β•šβ•β• β•šβ•β•β•β•β•β• β•šβ•β•β•β•β•β•β• β•šβ•β• β•‘ +β•‘ β•‘ +║═══════════════════════════════════════════════════════════║ +%s +║═══════════════════════════════════════════════════════════║ +β•‘ The Fastest LLM Gateway β•‘ +║═══════════════════════════════════════════════════════════║ +β•‘ https://github.com/maximhq/bifrost β•‘ +β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β• + +`, versionLine) + handlers.SetVersion(Version) + // Set default host from environment variable or use localhost + defaultHost := os.Getenv("BIFROST_HOST") + if defaultHost == "" { + defaultHost = DefaultHost + } + + flag.StringVar(&port, "port", DefaultPort, "Port to run the server on") + flag.StringVar(&host, "host", defaultHost, "Host to bind the server to (default: localhost, override with BIFROST_HOST env var)") + flag.StringVar(&appDir, "app-dir", DefaultAppDir, "Application data directory (contains config.json and logs)") + flag.StringVar(&logLevel, "log-level", DefaultLogLevel, "Logger level (debug, info, warn, error). Default is info.") + flag.StringVar(&logOutputStyle, "log-style", DefaultLogOutputStyle, "Logger output type (json or pretty). Default is JSON.") + flag.Parse() + + // Configure logger from flags + logger.SetOutputType(schemas.LoggerOutputType(logOutputStyle)) + logger.SetLevel(schemas.LogLevel(logLevel)) +} + +// registerCollectorSafely attempts to register a Prometheus collector, +// handling the case where it may already be registered. +// It logs any errors that occur during registration, except for AlreadyRegisteredError. +func registerCollectorSafely(collector prometheus.Collector) { + if err := prometheus.Register(collector); err != nil { + if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { + logger.Error("failed to register prometheus collector: %v", err) + } + } +} + +// corsMiddleware handles CORS headers for localhost and configured allowed origins +func corsMiddleware(config *lib.Config, next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + origin := string(ctx.Request.Header.Peek("Origin")) + + // Check if origin is allowed (localhost always allowed + configured origins) + if handlers.IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins) { + ctx.Response.Header.Set("Access-Control-Allow-Origin", origin) + } + + ctx.Response.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + ctx.Response.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") + ctx.Response.Header.Set("Access-Control-Allow-Credentials", "true") + ctx.Response.Header.Set("Access-Control-Max-Age", "86400") + + // Handle preflight OPTIONS requests + if string(ctx.Method()) == "OPTIONS" { + ctx.SetStatusCode(fasthttp.StatusOK) + return + } + + next(ctx) + } +} + +// uiHandler serves the embedded Next.js UI files +func uiHandler(ctx *fasthttp.RequestCtx) { + // Get the request path + requestPath := string(ctx.Path()) + + // Clean the path to prevent directory traversal + cleanPath := path.Clean(requestPath) + + // Handle .txt files (Next.js RSC payload files) - map from /{page}.txt to /{page}/index.txt + if strings.HasSuffix(cleanPath, ".txt") { + // Remove .txt extension and add /index.txt + basePath := strings.TrimSuffix(cleanPath, ".txt") + if basePath == "/" || basePath == "" { + basePath = "/index" + } + cleanPath = basePath + "/index.txt" + } + + // Remove leading slash and add ui prefix + if cleanPath == "/" { + cleanPath = "ui/index.html" + } else { + cleanPath = "ui" + cleanPath + } + + // Check if this is a static asset request (has file extension) + hasExtension := strings.Contains(filepath.Base(cleanPath), ".") + + // Try to read the file from embedded filesystem + data, err := uiContent.ReadFile(cleanPath) + if err != nil { + + // If it's a static asset (has extension) and not found, return 404 + if hasExtension { + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetBodyString("404 - Static asset not found: " + requestPath) + return + } + + // For routes without extensions (SPA routing), try {path}/index.html first + if !hasExtension { + indexPath := cleanPath + "/index.html" + data, err = uiContent.ReadFile(indexPath) + if err == nil { + cleanPath = indexPath + } else { + // If that fails, serve root index.html as fallback + data, err = uiContent.ReadFile("ui/index.html") + if err != nil { + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetBodyString("404 - File not found") + return + } + cleanPath = "ui/index.html" + } + } else { + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetBodyString("404 - File not found") + return + } + } + + // Set content type based on file extension + ext := filepath.Ext(cleanPath) + contentType := mime.TypeByExtension(ext) + if contentType == "" { + contentType = "application/octet-stream" + } + ctx.SetContentType(contentType) + + // Set cache headers for static assets + if strings.HasPrefix(cleanPath, "ui/_next/static/") { + ctx.Response.Header.Set("Cache-Control", "public, max-age=31536000, immutable") + } else if ext == ".html" { + ctx.Response.Header.Set("Cache-Control", "no-cache") + } else { + ctx.Response.Header.Set("Cache-Control", "public, max-age=3600") + } + + // Send the file content + ctx.SetBody(data) +} + +// GetDefaultConfigDir returns the OS-specific default configuration directory for Bifrost. +// This follows standard conventions: +// - Linux/macOS: ~/.config/bifrost +// - Windows: %APPDATA%\bifrost +// - If appDir is provided (non-empty), it returns that instead +func getDefaultConfigDir(appDir string) string { + // If appDir is provided, use it directly + if appDir != "" && appDir != "./bifrost-data" { + return appDir + } + + // Get OS-specific config directory + var configDir string + switch runtime.GOOS { + case "windows": + // Windows: %APPDATA%\bifrost + if appData := os.Getenv("APPDATA"); appData != "" { + configDir = filepath.Join(appData, "bifrost") + } else { + // Fallback to user home directory + if homeDir, err := os.UserHomeDir(); err == nil { + configDir = filepath.Join(homeDir, "AppData", "Roaming", "bifrost") + } + } + default: + // Linux, macOS and other Unix-like systems: ~/.config/bifrost + if homeDir, err := os.UserHomeDir(); err == nil { + configDir = filepath.Join(homeDir, ".config", "bifrost") + } + } + + // If we couldn't determine the config directory, fall back to current directory + if configDir == "" { + configDir = "./bifrost-data" + } + + return configDir +} + +// main is the entry point of the application. +// It: +// 1. Initializes Prometheus collectors for monitoring +// 2. Reads and parses configuration from the specified config file +// 3. Initializes the Bifrost client with the configuration +// 4. Sets up HTTP routes for text and chat completions +// 5. Starts the HTTP server on the specified host and port +// +// The server exposes the following endpoints: +// - POST /v1/text/completions: For text completion requests +// - POST /v1/chat/completions: For chat completion requests +// - GET /metrics: For Prometheus metrics +func main() { + ctx := context.Background() + configDir := getDefaultConfigDir(appDir) + // Ensure app directory exists + if err := os.MkdirAll(configDir, 0755); err != nil { + logger.Fatal("failed to create app directory %s: %v", configDir, err) + } + + // Register Prometheus collectors + registerCollectorSafely(collectors.NewGoCollector()) + registerCollectorSafely(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + + // Initialize high-performance configuration store with dedicated database + config, err := lib.LoadConfig(ctx, configDir) + if err != nil { + logger.Fatal("failed to load config %v", err) + } + + // Initialize pricing manager + pricingManager, err := pricing.Init(config.ConfigStore, logger) + if err != nil { + logger.Error("failed to initialize pricing manager: %v", err) + } + + // Create account backed by the high-performance store (all processing is done in LoadFromDatabase) + // The account interface now benefits from ultra-fast config access times via in-memory storage + account := lib.NewBaseAccount(config) + + // Initialize plugins + loadedPlugins := []schemas.Plugin{} + + telemetry.InitPrometheusMetrics(config.ClientConfig.PrometheusLabels) + logger.Debug("prometheus Go/Process collectors registered.") + + promPlugin := telemetry.Init(pricingManager, logger) + + loadedPlugins = append(loadedPlugins, promPlugin) + + var loggingPlugin *logging.LoggerPlugin + var loggingHandler *handlers.LoggingHandler + var wsHandler *handlers.WebSocketHandler + + if config.ClientConfig.EnableLogging && config.LogsStore != nil { + // Use dedicated logs database with high-scale optimizations + loggingPlugin, err = logging.Init(logger, config.LogsStore, pricingManager) + if err != nil { + logger.Fatal("failed to initialize logging plugin: %v", err) + } + + loadedPlugins = append(loadedPlugins, loggingPlugin) + loggingHandler = handlers.NewLoggingHandler(loggingPlugin.GetPluginLogManager(), logger) + wsHandler = handlers.NewWebSocketHandler(loggingPlugin.GetPluginLogManager(), logger, config.ClientConfig.AllowedOrigins) + } + + var governancePlugin *governance.GovernancePlugin + var governanceHandler *handlers.GovernanceHandler + + if config.ClientConfig.EnableGovernance { + // Initialize governance plugin + governancePlugin, err = governance.Init(ctx, &governance.Config{ + IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, + }, logger, config.ConfigStore, config.GovernanceConfig, pricingManager) + if err != nil { + logger.Error("failed to initialize governance plugin: %s", err.Error()) + } else { + loadedPlugins = append(loadedPlugins, governancePlugin) + + governanceHandler, err = handlers.NewGovernanceHandler(governancePlugin, config.ConfigStore, logger) + if err != nil { + logger.Error("failed to initialize governance handler: %s", err.Error()) + } + } + } + + // Currently we support first party plugins only + // Eventually same flow will be used for third party plugins + for _, plugin := range config.Plugins { + if !plugin.Enabled { + logger.Debug("plugin %s is disabled, skipping initialization", plugin.Name) + continue + } + switch strings.ToLower(plugin.Name) { + case maxim.PluginName: + + var maximConfig maxim.Config + if plugin.Config != nil { + configBytes, err := json.Marshal(plugin.Config) + if err != nil { + logger.Fatal("failed to marshal maxim config: %v", err) + } + if err := json.Unmarshal(configBytes, &maximConfig); err != nil { + logger.Fatal("failed to unmarshal maxim config: %v", err) + } + } + + maximPlugin, err := maxim.Init(maximConfig) + if err != nil { + logger.Warn("failed to initialize maxim plugin: %v", err) + } else { + loadedPlugins = append(loadedPlugins, maximPlugin) + } + case semanticcache.PluginName: + if config.VectorStore == nil { + logger.Error("vector store is required to initialize semantic cache plugin, skipping initialization") + continue + } + + // Convert config map to semanticcache.Config struct + var semCacheConfig semanticcache.Config + if plugin.Config != nil { + configBytes, err := json.Marshal(plugin.Config) + if err != nil { + logger.Fatal("failed to marshal semantic cache config: %v", err) + } + if err := json.Unmarshal(configBytes, &semCacheConfig); err != nil { + logger.Fatal("failed to unmarshal semantic cache config: %v", err) + } + } + + semanticCachePlugin, err := semanticcache.Init(ctx, semCacheConfig, logger, config.VectorStore) + if err != nil { + logger.Error("failed to initialize semantic cache: %v", err) + } else { + loadedPlugins = append(loadedPlugins, semanticCachePlugin) + logger.Info("successfully initialized semantic cache") + } + } + } + + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + InitialPoolSize: config.ClientConfig.InitialPoolSize, + DropExcessRequests: config.ClientConfig.DropExcessRequests, + Plugins: loadedPlugins, + MCPConfig: config.MCPConfig, + Logger: logger, + }) + if err != nil { + logger.Fatal("failed to initialize bifrost: %v", err) + } + + config.SetBifrostClient(client) + + // Initialize handlers + providerHandler := handlers.NewProviderHandler(config, client, logger) + completionHandler := handlers.NewCompletionHandler(client, config, logger) + mcpHandler := handlers.NewMCPHandler(client, logger, config) + integrationHandler := handlers.NewIntegrationHandler(client, config) + configHandler := handlers.NewConfigHandler(client, logger, config) + pluginsHandler := handlers.NewPluginsHandler(config.ConfigStore, logger) + + var cacheHandler *handlers.CacheHandler + for _, plugin := range loadedPlugins { + if plugin.GetName() == semanticcache.PluginName { + cacheHandler = handlers.NewCacheHandler(plugin, logger) + } + } + + // Set up WebSocket callback for real-time log updates + if wsHandler != nil && loggingPlugin != nil { + loggingPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate) + // Start WebSocket heartbeat + wsHandler.StartHeartbeat() + } + + r := router.New() + + // Register all handler routes + providerHandler.RegisterRoutes(r) + completionHandler.RegisterRoutes(r) + mcpHandler.RegisterRoutes(r) + integrationHandler.RegisterRoutes(r) + configHandler.RegisterRoutes(r) + pluginsHandler.RegisterRoutes(r) + if cacheHandler != nil { + cacheHandler.RegisterRoutes(r) + } + if governanceHandler != nil { + governanceHandler.RegisterRoutes(r) + } + if loggingHandler != nil { + loggingHandler.RegisterRoutes(r) + } + if wsHandler != nil { + wsHandler.RegisterRoutes(r) + } + + // Add Prometheus /metrics endpoint + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler())) + + // Add UI routes - serve the embedded Next.js build + r.GET("/", uiHandler) + r.GET("/{filepath:*}", uiHandler) + + r.NotFound = func(ctx *fasthttp.RequestCtx) { + handlers.SendError(ctx, fasthttp.StatusNotFound, "Route not found: "+string(ctx.Path()), logger) + } + + // Apply CORS middleware to all routes + corsHandler := corsMiddleware(config, r.Handler) + + // Create fasthttp server instance + server := &fasthttp.Server{ + Handler: corsHandler, + MaxRequestBodySize: config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024, + } + + // Create channels for signal and error handling + sigChan := make(chan os.Signal, 1) + errChan := make(chan error, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Start server in a goroutine + serverAddr := net.JoinHostPort(host, port) + go func() { + logger.Info("successfully started bifrost, serving UI on http://%s:%s", host, port) + if err := server.ListenAndServe(serverAddr); err != nil { + errChan <- err + } + }() + + // Wait for either termination signal or server error + select { + case sig := <-sigChan: + logger.Info("received signal %v, initiating graceful shutdown...", sig) + // Create shutdown context with timeout + shutdownCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + // Perform graceful shutdown + if err := server.Shutdown(); err != nil { + logger.Error("error during graceful shutdown: %v", err) + } else { + logger.Info("server gracefully shutdown") + } + + // Wait for shutdown to complete or timeout + done := make(chan struct{}) + go func() { + defer close(done) + // Cleanup resources + if wsHandler != nil { + wsHandler.Stop() + } + client.Shutdown() + }() + + select { + case <-done: + logger.Info("cleanup completed") + case <-shutdownCtx.Done(): + logger.Warn("cleanup timed out after 30 seconds") + } + + case err := <-errChan: + logger.Error("server failed to start: %v", err) + os.Exit(1) + } +} diff --git a/transports/changelog.md b/transports/changelog.md new file mode 100644 index 000000000..9c26654f0 --- /dev/null +++ b/transports/changelog.md @@ -0,0 +1,4 @@ + + + +- Fixes pricing computation for nested model names i.e. groq/openai/gpt-oss-20b. \ No newline at end of file diff --git a/transports/config.example.json b/transports/config.example.json deleted file mode 100644 index 159aecac6..000000000 --- a/transports/config.example.json +++ /dev/null @@ -1,117 +0,0 @@ -{ - "OpenAI": { - "keys": [ - { - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o-mini", "gpt-4-turbo"], - "weight": 1.0 - } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Anthropic": { - "keys": [ - { - "value": "env.ANTHROPIC_API_KEY", - "models": [ - "claude-3-7-sonnet-20250219", - "claude-3-5-sonnet-20240620", - "claude-2.1" - ], - "weight": 1.0 - } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Bedrock": { - "keys": [ - { - "value": "env.BEDROCK_API_KEY", - "models": [ - "anthropic.claude-v2:1", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", - "anthropic.claude-3-sonnet-20240229-v1:0" - ], - "weight": 1.0 - } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "meta_config": { - "secret_access_key": "env.BEDROCK_ACCESS_KEY", - "region": "us-east-1" - }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Cohere": { - "keys": [ - { - "value": "env.COHERE_API_KEY", - "models": ["command-a-03-2025"], - "weight": 1.0 - } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Azure": { - "keys": [ - { - "value": "env.AZURE_API_KEY", - "models": ["gpt-4o"], - "weight": 1.0 - } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "meta_config": { - "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-aug" - }, - "api_version": "2024-08-01-preview" - }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - } -} diff --git a/transports/config.schema.json b/transports/config.schema.json new file mode 100644 index 000000000..536f52fe4 --- /dev/null +++ b/transports/config.schema.json @@ -0,0 +1,853 @@ +{ + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "https://www.getbifrost.ai/schema", + "title": "Bifrost Configuration Schema", + "description": "Schema for Bifrost HTTP transport configuration", + "type": "object", + "properties": { + "client": { + "type": "object", + "description": "Client configuration settings", + "properties": { + "drop_excess_requests": { + "type": "boolean", + "description": "Whether to drop excess requests when pool is full" + }, + "initial_pool_size": { + "type": "integer", + "minimum": 1, + "description": "Initial size of the connection pool", + "default": 300 + }, + "prometheus_labels": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Labels to use for Prometheus metrics" + }, + "allowed_origins": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string", + "const": "*" + }, + { + "type": "string", + "format": "uri" + } + ] + }, + "description": "CORS allowed origins (supports \"*\" or URI strings)" + }, + "enable_logging": { + "type": "boolean", + "description": "Enable request/response logging" + }, + "enable_governance": { + "type": "boolean", + "description": "Enable governance features" + }, + "enforce_governance_header": { + "type": "boolean", + "description": "Enforce governance header. This will require every incoming request to include x-bf-vk header." + }, + "allow_direct_keys": { + "type": "boolean", + "description": "Allow provider keys" + }, + "max_request_body_size_mb": { + "type": "integer", + "minimum": 1, + "description": "Maximum request body size in MB" + } + }, + "additionalProperties": false + }, + "providers": { + "type": "object", + "description": "AI provider configurations", + "properties": { + "openai": { + "$ref": "#/$defs/provider" + }, + "anthropic": { + "$ref": "#/$defs/provider" + }, + "bedrock": { + "$ref": "#/$defs/provider_with_bedrock_config" + }, + "cohere": { + "$ref": "#/$defs/provider" + }, + "azure": { + "$ref": "#/$defs/provider_with_azure_config" + }, + "vertex": { + "$ref": "#/$defs/provider_with_vertex_config" + }, + "mistral": { + "$ref": "#/$defs/provider" + }, + "ollama": { + "$ref": "#/$defs/provider" + }, + "groq": { + "$ref": "#/$defs/provider" + }, + "gemini": { + "$ref": "#/$defs/provider" + }, + "openrouter": { + "$ref": "#/$defs/provider" + }, + "sgl": { + "$ref": "#/$defs/provider" + }, + "parasail": { + "$ref": "#/$defs/provider" + }, + "cerebras": { + "$ref": "#/$defs/provider" + } + }, + "additionalProperties": true + }, + "mcp": { + "type": "object", + "description": "Model Context Protocol configuration", + "properties": { + "client_configs": { + "type": "array", + "items": { + "$ref": "#/$defs/mcp_client_config" + }, + "description": "MCP client configurations" + } + }, + "additionalProperties": false + }, + "vector_store": { + "type": "object", + "description": "Vector store configuration for caching", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable vector store" + }, + "type": { + "type": "string", + "enum": [ + "weaviate" + ], + "description": "Vector store type" + }, + "config": { + "anyOf": [ + { + "if": { + "properties": { + "type": { + "const": "weaviate" + } + } + }, + "then": { + "$ref": "#/$defs/weaviate_config" + } + } + ] + } + }, + "additionalProperties": false + }, + "config_store": { + "type": "object", + "description": "Configuration store settings", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable configuration store" + }, + "type": { + "type": "string", + "enum": [ + "sqlite" + ], + "description": "Configuration store type" + }, + "config": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Database file path" + } + }, + "required": [ + "path" + ], + "additionalProperties": false + } + }, + "additionalProperties": false + }, + "logs_store": { + "type": "object", + "description": "Logs store settings", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable logs store" + }, + "type": { + "type": "string", + "enum": [ + "sqlite" + ], + "description": "Logs store type" + }, + "config": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Database file path" + } + }, + "required": [ + "path" + ], + "additionalProperties": false + } + }, + "additionalProperties": false + }, + "plugins": { + "type": "object", + "description": "Plugins configuration", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable plugins" + }, + "name":{ + "type": "string", + "description": "Name of the plugin" + }, + "config":{ + "type": "object", + "description": "Configuration for the plugin" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false, + "$defs": { + "network_config": { + "type": "object", + "properties": { + "base_url": { + "type": "string", + "format": "uri", + "description": "Base URL for the provider (optional, required for Ollama)" + }, + "extra_headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Additional headers to send with requests" + }, + "default_request_timeout_in_seconds": { + "type": "integer", + "minimum": 1, + "description": "Default request timeout in seconds" + }, + "max_retries": { + "type": "integer", + "minimum": 0, + "description": "Maximum number of retries" + }, + "retry_backoff_initial_ms": { + "type": "integer", + "minimum": 0, + "description": "Initial retry backoff in milliseconds" + }, + "retry_backoff_max_ms": { + "type": "integer", + "minimum": 0, + "description": "Maximum retry backoff in milliseconds" + } + }, + "additionalProperties": false + }, + "concurrency_config": { + "type": "object", + "properties": { + "concurrency": { + "type": "integer", + "minimum": 1, + "description": "Number of concurrent requests" + }, + "buffer_size": { + "type": "integer", + "minimum": 1, + "description": "Buffer size for requests" + } + }, + "required": [ + "concurrency", + "buffer_size" + ], + "additionalProperties": false + }, + "base_key": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "API key value (can use env. prefix)" + }, + "models": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Supported models for this key" + }, + "weight": { + "type": "number", + "minimum": 0, + "description": "Weight for load balancing" + } + }, + "required": [ + "weight" + ], + "additionalProperties": false + }, + "bedrock_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "bedrock_key_config": { + "type": "object", + "properties": { + "access_key": { + "type": "string", + "description": "AWS access key (can use env. prefix)" + }, + "secret_key": { + "type": "string", + "description": "AWS secret key (can use env. prefix)" + }, + "session_token": { + "type": "string", + "description": "AWS session token (can use env. prefix)" + }, + "deployments": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Model to deployment mappings" + }, + "arn": { + "type": "string", + "description": "AWS ARN" + }, + "region": { + "type": "string", + "description": "AWS region" + } + }, + "required": [ + "region" + ], + "additionalProperties": false + } + }, + "required": [ + "bedrock_key_config" + ] + } + ] + }, + "azure_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "azure_key_config": { + "type": "object", + "properties": { + "endpoint": { + "type": "string", + "description": "Azure endpoint (can use env. prefix)" + }, + "deployments": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Model to deployment mappings" + }, + "api_version": { + "type": "string", + "description": "Azure API version" + } + }, + "required": [ + "endpoint", + "api_version" + ], + "additionalProperties": false + } + }, + "required": [ + "azure_key_config" + ] + } + ] + }, + "vertex_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "vertex_key_config": { + "type": "object", + "properties": { + "project_id": { + "type": "string", + "description": "Google Cloud project ID (can use env. prefix)" + }, + "region": { + "type": "string", + "description": "Google Cloud region" + }, + "auth_credentials": { + "type": "string", + "description": "Authentication credentials (can use env. prefix)" + } + }, + "required": [ + "project_id", + "region" + ], + "additionalProperties": false + } + }, + "required": [ + "vertex_key_config" + ] + } + ] + }, + "provider": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/base_key" + }, + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_bedrock_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/bedrock_key" + }, + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_azure_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/azure_key" + }, + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_vertex_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/vertex_key" + }, + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "mcp_client_config": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the MCP client" + }, + "connection_type": { + "type": "string", + "enum": [ + "stdio", + "websocket", + "http" + ], + "description": "Connection type for MCP client" + }, + "stdio_config": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Command to execute" + }, + "args": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Command arguments" + }, + "envs": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Environment variables" + } + }, + "required": [ + "command" + ], + "additionalProperties": false + }, + "websocket_config": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "WebSocket URL" + } + }, + "required": [ + "url" + ], + "additionalProperties": false + }, + "http_config": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "HTTP URL" + } + }, + "required": [ + "url" + ], + "additionalProperties": false + } + }, + "required": [ + "name", + "connection_type" + ], + "additionalProperties": false, + "oneOf": [ + { + "properties": { + "connection_type": { + "const": "stdio" + } + }, + "required": [ + "stdio_config" + ] + }, + { + "properties": { + "connection_type": { + "const": "websocket" + } + }, + "required": [ + "websocket_config" + ] + }, + { + "properties": { + "connection_type": { + "const": "http" + } + }, + "required": [ + "http_config" + ] + } + ] + }, + "weaviate_config": { + "type": "object", + "description": "Weaviate configuration for vector store", + "properties": { + "scheme": { + "type": "string", + "description": "Weaviate server scheme (http or https) - REQUIRED" + }, + "host": { + "type": "string", + "description": "Weaviate server host (host:port) - REQUIRED" + }, + "api_key": { + "type": "string", + "description": "API key for Weaviate authentication (optional)" + }, + "grpc_config": { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "Weaviate server host (host:port). If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used." + }, + "secured": { + "type": "boolean", + "description": "Secured set it to true if it's a secured connection" + } + } + }, + "headers": { + "type": "object", + "description": "Additional headers to send with requests" + }, + "timeout": { + "type": "string", + "pattern": "^[0-9]+(ns|us|Β΅s|ms|s|m|h)$", + "description": "Timeout for Weaviate operations (e.g., '5s')" + }, + "class_name": { + "type": "string", + "description": "Class name for Weaviate vector store" + }, + "properties": { + "type": "array", + "items": { + "type": "object" + }, + "description": "Properties for Weaviate vector store" + } + }, + "required": [ + "scheme", + "host" + ], + "additionalProperties": false + }, + "proxy_config": { + "type": "object", + "description": "Proxy configuration for provider connections", + "properties": { + "type": { + "type": "string", + "enum": [ + "none", + "http", + "socks5", + "environment" + ], + "description": "Type of proxy to use" + }, + "url": { + "type": "string", + "format": "uri", + "description": "URL of the proxy server" + }, + "username": { + "type": "string", + "description": "Username for proxy authentication" + }, + "password": { + "type": "string", + "description": "Password for proxy authentication" + } + }, + "required": [ + "type" + ], + "additionalProperties": false + }, + "clusterConfig": { + "type": "object", + "description": "Cluster mode configuration", + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether cluster mode is enabled" + }, + "peers": { + "type": "array", + "description": "List of peer addresses", + "items": { + "type": "string", + "description": "Peer address in host:port format" + } + }, + "gossip": { + "type": "object", + "description": "Gossip protocol configuration", + "properties": { + "port": { + "type": "integer", + "minimum": 1, + "maximum": 65535, + "description": "Port for gossip communication" + }, + "config": { + "type": "object", + "description": "Gossip protocol settings", + "properties": { + "livenessProbeEverySeconds": { + "type": "integer", + "minimum": 1, + "description": "Interval between liveness probes in seconds" + }, + "timeoutSeconds": { + "type": "integer", + "minimum": 1, + "description": "Timeout for operations in seconds" + }, + "successThreshold": { + "type": "integer", + "minimum": 1, + "description": "Number of successful probes required" + }, + "failureThreshold": { + "type": "integer", + "minimum": 1, + "description": "Number of failed probes before marking as failed" + } + }, + "required": [ + "livenessProbeEverySeconds", + "timeoutSeconds", + "successThreshold", + "failureThreshold" + ], + "additionalProperties": false + } + }, + "required": [ + "port", + "config" + ], + "additionalProperties": false + } + }, + "required": [ + "enabled" + ], + "additionalProperties": false + } + } +} \ No newline at end of file diff --git a/transports/docker-entrypoint.sh b/transports/docker-entrypoint.sh new file mode 100644 index 000000000..b8e1580a6 --- /dev/null +++ b/transports/docker-entrypoint.sh @@ -0,0 +1,76 @@ +#!/bin/sh +set -e + +# Function to fix permissions on mounted volumes +fix_permissions() { + # Check if /app/data exists and fix ownership if needed + if [ -d "/app/data" ]; then + # Get current user info + CURRENT_UID=$(id -u) + CURRENT_GID=$(id -g) + + # Get directory ownership + DATA_UID=$(stat -c %u /app/data 2>/dev/null || echo "0") + DATA_GID=$(stat -c %g /app/data 2>/dev/null || echo "0") + + # If ownership doesn't match current user, try to fix it + if [ "$DATA_UID" != "$CURRENT_UID" ] || [ "$DATA_GID" != "$CURRENT_GID" ]; then + echo "Fixing permissions on /app/data (was $DATA_UID:$DATA_GID, setting to $CURRENT_UID:$CURRENT_GID)" + + # Try to change ownership (will work if running as root or if user has permission) + if chown -R "$CURRENT_UID:$CURRENT_GID" /app/data 2>/dev/null; then + echo "Successfully updated permissions on /app/data" + else + echo "Warning: Could not change ownership of /app/data. You may need to run:" + echo " docker run --user \$(id -u):\$(id -g) ..." + echo " or ensure the host directory is owned by UID:GID $CURRENT_UID:$CURRENT_GID" + fi + fi + + # Ensure logs subdirectory exists with correct permissions + mkdir -p /app/data/logs + chmod 755 /app/data/logs 2>/dev/null || true + fi +} + +# Fix permissions before starting the application +fix_permissions + +# Parse command line arguments and set environment variables +parse_args() { + while [ $# -gt 0 ]; do + case $1 in + --port|-port) + if [ -n "$2" ]; then + export APP_PORT="$2" + shift 2 + else + echo "Error: --port requires a value" + exit 1 + fi + ;; + --host|-host) + if [ -n "$2" ]; then + export APP_HOST="$2" + shift 2 + else + echo "Error: --host requires a value" + exit 1 + fi + ;; + *) + # Keep other arguments for the main application + set -- "$@" "$1" + shift + ;; + esac + done +} + +# Parse arguments if any are provided +if [ $# -gt 1 ]; then + parse_args "$@" +fi + +# Build the command with environment variables and standard arguments +exec /app/main -app-dir "$APP_DIR" -port "$APP_PORT" -host "$APP_HOST" -log-level "$LOG_LEVEL" -log-style "$LOG_STYLE" \ No newline at end of file diff --git a/transports/go.mod b/transports/go.mod index c92d309e3..62db49783 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -2,32 +2,116 @@ module github.com/maximhq/bifrost/transports go 1.24.1 +toolchain go1.24.3 + require ( + github.com/bytedance/sonic v1.14.0 github.com/fasthttp/router v1.5.4 - github.com/joho/godotenv v1.5.1 - github.com/maximhq/bifrost/core v1.0.2 - github.com/valyala/fasthttp v1.60.0 + github.com/fasthttp/websocket v1.5.12 + github.com/google/uuid v1.6.0 + github.com/maximhq/bifrost/core v1.1.37 + github.com/maximhq/bifrost/framework v1.0.23 + github.com/maximhq/bifrost/plugins/governance v1.2.16 + github.com/maximhq/bifrost/plugins/logging v1.2.16 + github.com/maximhq/bifrost/plugins/maxim v1.3.6 + github.com/maximhq/bifrost/plugins/semanticcache v1.2.18 + github.com/maximhq/bifrost/plugins/telemetry v1.2.15 + github.com/prometheus/client_golang v1.23.0 + github.com/valyala/fasthttp v1.65.0 + google.golang.org/genai v1.22.0 + gorm.io/gorm v1.30.1 ) require ( - github.com/andybalholm/brotli v1.1.1 // indirect - github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect - github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + cloud.google.com/go v0.121.6 // indirect + cloud.google.com/go/auth v0.16.5 // indirect + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect - github.com/aws/smithy-go v1.22.3 // indirect - github.com/goccy/go-json v0.10.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/runtime v0.24.2 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-openapi/validate v0.24.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect + github.com/mark3labs/mcp-go v0.37.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/maximhq/maxim-go v0.1.10 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.65.0 // indirect + github.com/prometheus/procfs v0.17.0 // indirect + github.com/redis/go-redis/v9 v9.12.1 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/text v0.24.0 // indirect + github.com/weaviate/weaviate v1.31.5 // indirect + github.com/weaviate/weaviate-go-client/v5 v5.2.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect + google.golang.org/grpc v1.74.2 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect ) diff --git a/transports/go.sum b/transports/go.sum index bab9764a1..7a0ea2fbf 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -1,52 +1,412 @@ -github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= -github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= -github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= -github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= -github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= -github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +cloud.google.com/go v0.121.6 h1:waZiuajrI28iAf40cWgycWNgaXPO06dupuS+sgibK6c= +cloud.google.com/go v0.121.6/go.mod h1:coChdst4Ea5vUpiALcYKXEpR1S9ZgXbhEzzMcMR66vI= +cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= +cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= +github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= -github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= -github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= +github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.19.8/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.9/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.20.2/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.21.1/go.mod h1:/DtAMXXneXFjbQMGEtbamCZb+4x7eGwkvZCvBmwUG+g= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/runtime v0.24.2 h1:yX9HMGQbz32M87ECaAhGpJjBmErO3QLcgdZj9BzGx7c= +github.com/go-openapi/runtime v0.24.2/go.mod h1:AKurw9fNre+h3ELZfk6ILsfvPN+bvvlaU/M9q/r9hpk= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.21.0/go.mod h1:ZRQ409bWMj+SOgXofQAGTIo2Ebu72Gs+WaRADcS5iNg= +github.com/go-openapi/strfmt v0.21.1/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.21.2/go.mod h1:I/XVKeLc5+MM5oPNN7P6urMOpuLXEcNrCX/rPGuWb0k= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= +github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/maximhq/bifrost/core v1.0.1 h1:B0u6o13faUexA+V0EUU0bsLW2dHg9+R2TZKQzPzCxlY= -github.com/maximhq/bifrost/core v1.0.1/go.mod h1:4+Ept2EnX1EEjH/mBuSwK7eE56znI/BCoCbIrx25/x8= -github.com/maximhq/bifrost/core v1.0.2 h1:GG1CGrvbz5lbdDudlJodKHx9pHr0VAoUd5lhgxUWc00= -github.com/maximhq/bifrost/core v1.0.2/go.mod h1:ZF8LVnUwVzHZ3SkCQPvXXmu0w3b4sjRLS6ij9aPYcjg= -github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= -github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.1.37 h1:jVFY1tQFY8T2r4S3RE1zN8cFp1Uw97Dec3Ud32rR8Uc= +github.com/maximhq/bifrost/core v1.1.37/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= +github.com/maximhq/bifrost/framework v1.0.23 h1:erRPP9Q0WIaUgxuLBN8urd77SObEF9irPvpV9Wbegyk= +github.com/maximhq/bifrost/framework v1.0.23/go.mod h1:uEB0iuQtFfuFuMrhccMsb+51mf8m8X2tB8ZlDVoJUbM= +github.com/maximhq/bifrost/plugins/governance v1.2.16 h1:a4WrRnmXRx/+YwGYt67zzammvFt8t4hJLffg4CbIgd8= +github.com/maximhq/bifrost/plugins/governance v1.2.16/go.mod h1:FXxJlyFMU9E3Vrzg1YSi6WCmDPzlc+0z+Td9jaFQzkE= +github.com/maximhq/bifrost/plugins/logging v1.2.16 h1:gXJfdV0yL3wL+tkOvr7pzSvw8XK3GIYm7VYO0QqJgZk= +github.com/maximhq/bifrost/plugins/logging v1.2.16/go.mod h1:yJYFA4rAr0sz3lyF8TWOY76a6HdcMc2Xe+mnAUvWWY4= +github.com/maximhq/bifrost/plugins/maxim v1.3.6 h1:y/JoP1GE2uL8xF80Y5FqPiG7IOaC47+dX1lgwzWVBSk= +github.com/maximhq/bifrost/plugins/maxim v1.3.6/go.mod h1:f9c2tzoQrOc+ILmiDhIlXwRQjgm+Qt1pBqy9fgCNBUo= +github.com/maximhq/bifrost/plugins/semanticcache v1.2.18 h1:dDDy8vgo8b+3ZF1aIkmddFQZy3qI8o2Wwhonriwk7HE= +github.com/maximhq/bifrost/plugins/semanticcache v1.2.18/go.mod h1:n6zkbVTp/YPl0DCTmojxOEyprSsfmVRhc9wVC7U6qLw= +github.com/maximhq/bifrost/plugins/telemetry v1.2.15 h1:8uFg67aWLgIfptbnbIr3keHGdJsP3x7HFyDwx0QsRSw= +github.com/maximhq/bifrost/plugins/telemetry v1.2.15/go.mod h1:yyz13bGb0RdWVzkwpljzaEgP6neSkZCeBbz/fqfUyr0= +github.com/maximhq/maxim-go v0.1.10 h1:rGBYSY3qld2zfZeL4HBmropkyfrqNiJ4IYA49jbvYX8= +github.com/maximhq/maxim-go v0.1.10/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= +github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= +github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= +github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= +github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= -github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/weaviate/weaviate v1.31.5 h1:YcmU1NcY2rdegWpE/mifS/9OisjE3I30JC7k6OgRlIE= +github.com/weaviate/weaviate v1.31.5/go.mod h1:CMgFYC2WIekOrNtyCQZ+HRJzJVCtrJYAdAkZVUVy45E= +github.com/weaviate/weaviate-go-client/v5 v5.2.0 h1:/HG0vFiKBK3JoOKo0mdk2XVYZ+oM0KfvCLG2ySr/FCA= +github.com/weaviate/weaviate-go-client/v5 v5.2.0/go.mod h1:nzR0ScRmbbutI+0pAjylj9Pt6upGVotnphiLWjy/QNA= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= +go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= +go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genai v1.22.0 h1:5hrEhXXWJQZa3tdPocl4vQ/0w6myEAxdNns2Kmx0f4Y= +google.golang.org/genai v1.22.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= +google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/transports/http/main.go b/transports/http/main.go deleted file mode 100644 index 8af6fb317..000000000 --- a/transports/http/main.go +++ /dev/null @@ -1,443 +0,0 @@ -// Package http provides an HTTP service using FastHTTP that exposes endpoints -// for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, etc.). - -// The HTTP service provides two main endpoints: -// - /v1/text/completions: For text completion requests -// - /v1/chat/completions: For chat completion requests - -// Configuration is handled through a JSON config file and environment variables: -// - Use -config flag to specify the config file location -// - Use -env flag to specify the .env file location -// - Use -port flag to specify the server port (default: 8080) -// - Use -pool-size flag to specify the initial connection pool size (default: 300) - -// try running the server with: -// go run http.go -config config.example.json -env .env -port 8080 -pool-size 300 -// after setting the environment variables present in config.example.json in your .env file. - -package main - -import ( - "encoding/json" - "errors" - "flag" - "fmt" - "log" - "os" - "reflect" - "strings" - "sync" - - "github.com/fasthttp/router" - "github.com/joho/godotenv" - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/core/schemas/meta" - "github.com/valyala/fasthttp" -) - -// Command line flags -var ( - initialPoolSize int // Initial size of the connection pool - dropExcessRequests bool // Drop excess requests - port string // Port to run the server on - configPath string // Path to the config file - envPath string // Path to the .env file -) - -// init initializes command line flags with default values. -// It also checks for environment variables that might override the defaults. -func init() { - flag.IntVar(&initialPoolSize, "pool-size", 300, "Initial pool size for Bifrost") - flag.StringVar(&port, "port", "8080", "Port to run the server on") - flag.StringVar(&configPath, "config", "", "Path to the config file") - flag.StringVar(&envPath, "env", "", "Path to the .env file") - flag.BoolVar(&dropExcessRequests, "drop-excess-requests", false, "Drop excess requests") - flag.Parse() - - if configPath == "" { - log.Fatalf("config path is required") - } - - if envPath == "" { - log.Fatalf("env path is required") - } -} - -// ProviderConfig represents the configuration for a specific AI model provider. -// It includes API keys, network settings, provider-specific metadata, and concurrency settings. -type ProviderConfig struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings - MetaConfig *schemas.MetaConfig `json:"-"` // Provider-specific metadata - ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings -} - -// ConfigMap maps provider names to their configurations. -type ConfigMap map[schemas.ModelProvider]ProviderConfig - -// readConfig reads and parses the configuration file. -// It handles case conversion for provider names and sets up provider-specific metadata. -// Returns a ConfigMap containing all provider configurations. -// Panics if the config file cannot be read or parsed. -// -// In the config file, use placeholder keys (e.g., env.OPENAI_API_KEY) instead of hardcoding actual values. -// These placeholders will be replaced with the corresponding values from the .env file. -// Location of the .env file is specified by the -env flag. It -// Example: -// -// "keys":[{ -// "value": "env.OPENAI_API_KEY" -// "models": ["gpt-4o-mini", "gpt-4-turbo"], -// "weight": 1.0 -// }] -// -// In this example, OPENAI_API_KEY refers to a key in the .env file. At runtime, its value will be used to replace the placeholder. -// Same setup applies to keys in meta configs of all the providers. -// Example: -// -// "meta_config": { -// "secret_access_key": "env.BEDROCK_ACCESS_KEY" -// "region": "env.BEDROCK_REGION" -// } -// -// In this example, BEDROCK_ACCESS_KEY and BEDROCK_REGION refer to keys in the .env file. -func readConfig(configLocation string) ConfigMap { - data, err := os.ReadFile(configLocation) - if err != nil { - log.Fatalf("failed to read config JSON file: %v", err) - } - - // First unmarshal into a map with string keys to handle case conversion - var rawConfig map[string]ProviderConfig - if err := json.Unmarshal(data, &rawConfig); err != nil { - log.Fatalf("failed to unmarshal JSON: %v", err) - } - - if rawConfig == nil { - log.Fatalf("provided config is nil") - } - - // Create a new config map with lowercase provider names - config := make(ConfigMap) - for rawProvider, cfg := range rawConfig { - provider := schemas.ModelProvider(strings.ToLower(rawProvider)) - - switch provider { - case schemas.Azure: - var azureMetaConfig meta.AzureMetaConfig - if err := json.Unmarshal(data, &struct { - Azure struct { - MetaConfig *meta.AzureMetaConfig `json:"meta_config"` - } `json:"Azure"` - }{Azure: struct { - MetaConfig *meta.AzureMetaConfig `json:"meta_config"` - }{&azureMetaConfig}}); err != nil { - log.Printf("warning: failed to unmarshal Azure meta config: %v", err) - } - var metaConfig schemas.MetaConfig = &azureMetaConfig - cfg.MetaConfig = &metaConfig - case schemas.Bedrock: - var bedrockMetaConfig meta.BedrockMetaConfig - if err := json.Unmarshal(data, &struct { - Bedrock struct { - MetaConfig *meta.BedrockMetaConfig `json:"meta_config"` - } `json:"Bedrock"` - }{Bedrock: struct { - MetaConfig *meta.BedrockMetaConfig `json:"meta_config"` - }{&bedrockMetaConfig}}); err != nil { - log.Printf("warning: failed to unmarshal Bedrock meta config: %v", err) - } - var metaConfig schemas.MetaConfig = &bedrockMetaConfig - cfg.MetaConfig = &metaConfig - } - - config[provider] = cfg - } - - return config -} - -// BaseAccount implements the Account interface for Bifrost. -// It manages provider configurations and API keys. -type BaseAccount struct { - Config ConfigMap // Map of provider configurations - mu sync.Mutex // Mutex to protect Config access -} - -// GetConfiguredProviders returns a list of all configured providers. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - providers := make([]schemas.ModelProvider, 0, len(baseAccount.Config)) - for provider := range baseAccount.Config { - providers = append(providers, provider) - } - return providers, nil -} - -// GetKeysForProvider returns the API keys configured for a specific provider. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - return baseAccount.Config[providerKey].Keys, nil -} - -// GetConfigForProvider returns the complete configuration for a specific provider. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - config, exists := baseAccount.Config[providerKey] - if !exists { - return nil, errors.New("config for provider not found") - } - - providerConfig := &schemas.ProviderConfig{} - - if config.NetworkConfig != nil { - providerConfig.NetworkConfig = *config.NetworkConfig - } - - if config.MetaConfig != nil { - providerConfig.MetaConfig = *config.MetaConfig - } - - if config.ConcurrencyAndBufferSize != nil { - providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize - } - - return providerConfig, nil -} - -// readKeys reads environment variables from a .env file and updates the provider configurations. -// It replaces values starting with "env." in the config with actual values from the environment. -// Returns an error if any required environment variable is missing. -func (baseAccount *BaseAccount) readKeys(envLocation string) error { - envVars, err := godotenv.Read(envLocation) - if err != nil { - return fmt.Errorf("failed to read .env file: %w", err) - } - - // Helper function to check and replace env values - replaceEnvValue := func(value string) (string, error) { - if strings.HasPrefix(value, "env.") { - envKey := strings.TrimPrefix(value, "env.") - if envValue, exists := envVars[envKey]; exists { - return envValue, nil - } - return "", fmt.Errorf("environment variable %s not found in .env file", envKey) - } - return value, nil - } - - // Helper function to recursively check and replace env values in a struct - var processStruct func(interface{}) error - processStruct = func(v interface{}) error { - val := reflect.ValueOf(v) - - // Dereference pointer if present - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - // Handle interface types - if val.Kind() == reflect.Interface { - val = val.Elem() - // If the interface value is a pointer, dereference it - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - } - - if val.Kind() != reflect.Struct { - return nil - } - - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - - // Skip unexported fields - if !field.CanSet() { - continue - } - - switch field.Kind() { - case reflect.String: - if field.CanSet() { - value := field.String() - if strings.HasPrefix(value, "env.") { - newValue, err := replaceEnvValue(value) - if err != nil { - return fmt.Errorf("field %s: %w", fieldType.Name, err) - } - field.SetString(newValue) - } - } - case reflect.Interface: - if !field.IsNil() { - if err := processStruct(field.Interface()); err != nil { - return err - } - } - } - } - return nil - } - - // Lock the config map for the entire update operation - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - // Check and replace values in provider configs - for provider, config := range baseAccount.Config { - // Check keys - for i, key := range config.Keys { - newValue, err := replaceEnvValue(key.Value) - if err != nil { - return fmt.Errorf("provider %s: %w", provider, err) - } - config.Keys[i].Value = newValue - } - - // Check meta config if it exists - if config.MetaConfig != nil { - if err := processStruct(config.MetaConfig); err != nil { - return fmt.Errorf("provider %s: %w", provider, err) - } - } - - baseAccount.Config[provider] = config - } - - return nil -} - -// CompletionRequest represents a request for either text or chat completion. -// It includes all necessary fields for both types of completions. -type CompletionRequest struct { - Provider schemas.ModelProvider `json:"provider"` // The AI model provider to use - Messages []schemas.Message `json:"messages"` // Chat messages (for chat completion) - Text string `json:"text"` // Text input (for text completion) - Model string `json:"model"` // Model to use - Params *schemas.ModelParameters `json:"params"` // Additional model parameters - Fallbacks []schemas.Fallback `json:"fallbacks"` // Fallback providers and models -} - -// handleCompletion processes both text and chat completion requests. -// It handles request parsing, validation, and response formatting. -func handleCompletion(ctx *fasthttp.RequestCtx, client *bifrost.Bifrost, isChat bool) { - var req CompletionRequest - if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString(fmt.Sprintf("invalid request format: %v", err)) - return - } - - if req.Provider == "" { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Provider is required") - return - } - - bifrostReq := &schemas.BifrostRequest{ - Model: req.Model, - Params: req.Params, - Fallbacks: req.Fallbacks, - } - - if isChat { - if len(req.Messages) == 0 { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Messages array is required") - return - } - bifrostReq.Input = schemas.RequestInput{ - ChatCompletionInput: &req.Messages, - } - } else { - if req.Text == "" { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Text is required") - return - } - bifrostReq.Input = schemas.RequestInput{ - TextCompletionInput: &req.Text, - } - } - - var resp *schemas.BifrostResponse - var err *schemas.BifrostError - if isChat { - resp, err = client.ChatCompletionRequest(req.Provider, bifrostReq, ctx) - } else { - resp, err = client.TextCompletionRequest(req.Provider, bifrostReq, ctx) - } - - if err != nil { - if err.IsBifrostError { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - } else { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - } - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(err) - return - } - - ctx.SetStatusCode(fasthttp.StatusOK) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(resp) -} - -// main is the entry point of the application. -// It: -// 1. Reads and parses configuration -// 2. Initializes the Bifrost client -// 3. Sets up HTTP routes -// 4. Starts the HTTP server -func main() { - config := readConfig(configPath) - account := &BaseAccount{Config: config} - - if err := account.readKeys(envPath); err != nil { - log.Printf("warning: failed to read environment variables: %v", err) - } - - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: account, - InitialPoolSize: initialPoolSize, - DropExcessRequests: dropExcessRequests, - }) - if err != nil { - log.Fatalf("failed to initialize bifrost: %v", err) - } - - r := router.New() - - r.POST("/v1/text/completions", func(ctx *fasthttp.RequestCtx) { - handleCompletion(ctx, client, false) - }) - - r.POST("/v1/chat/completions", func(ctx *fasthttp.RequestCtx) { - handleCompletion(ctx, client, true) - }) - - server := &fasthttp.Server{ - Handler: r.Handler, - } - - fmt.Printf("Starting HTTP server on port %s\n", port) - if err := server.ListenAndServe(fmt.Sprintf(":%s", port)); err != nil { - log.Fatalf("failed to start server: %v", err) - } - - client.Shutdown() -} diff --git a/transports/version b/transports/version new file mode 100644 index 000000000..9728bd69a --- /dev/null +++ b/transports/version @@ -0,0 +1 @@ +1.2.21 diff --git a/ui/.gitignore b/ui/.gitignore new file mode 100644 index 000000000..5ef6a5207 --- /dev/null +++ b/ui/.gitignore @@ -0,0 +1,41 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.* +.yarn/* +!.yarn/patches +!.yarn/plugins +!.yarn/releases +!.yarn/versions + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# env files (can opt-in for committing if needed) +.env* + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts diff --git a/ui/.prettierrc b/ui/.prettierrc new file mode 100644 index 000000000..f73138e64 --- /dev/null +++ b/ui/.prettierrc @@ -0,0 +1,21 @@ +{ + "printWidth": 140, + "singleQuote": false, + "bracketSpacing": true, + "semi": true, + "bracketSameLine": false, + "useTabs": true, + "tabWidth": 2, + "trailingComma": "all", + "plugins": [ + "prettier-plugin-tailwindcss" + ], + "tailwindAttributes": [ + "buttonClassname" + ], + "tailwindFunctions": [ + "cn", + "classNames" + ], + "endOfLine": "lf" +} \ No newline at end of file diff --git a/ui/README.md b/ui/README.md new file mode 100644 index 000000000..2bf1ed5ae --- /dev/null +++ b/ui/README.md @@ -0,0 +1,246 @@ +# Bifrost UI + +A modern, production-ready dashboard for the [Bifrost AI Gateway](https://github.com/maximhq/bifrost) - providing real-time monitoring, configuration management, and comprehensive observability for your AI infrastructure. + +## 🌟 Overview + +Bifrost UI is a Next.js-powered web dashboard that serves as the control center for your Bifrost AI Gateway. It provides an intuitive interface to monitor AI requests, configure providers, manage MCP clients, and extend functionality through plugins. + +### Key Features + +- **πŸ”΄ Real-time Log Monitoring** - Live streaming dashboard with WebSocket integration +- **βš™οΈ Provider Management** - Configure 8+ AI providers (OpenAI, Azure, Anthropic, Bedrock, etc.) +- **πŸ”Œ MCP Integration** - Manage Model Context Protocol clients for advanced AI capabilities +- **🧩 Plugin System** - Extend functionality with observability, testing, and custom plugins +- **πŸ“Š Analytics Dashboard** - Request metrics, success rates, latency tracking, and token usage +- **🎨 Modern UI** - Dark/light mode, responsive design, and accessible components +- **πŸ“š Documentation Hub** - Built-in documentation browser and quick-start guides + +## πŸš€ Quick Start + +### Development + +```bash +# Install dependencies +npm install + +# Start development server +npm run dev +``` + +The development server runs on `http://localhost:3000` and connects to your Bifrost HTTP transport backend (default: `http://localhost:8080`). + +### Environment Variables + +```bash +# Development only - customize Bifrost backend port +NEXT_PUBLIC_BIFROST_PORT=8080 +``` + +## πŸ—οΈ Architecture + +### Technology Stack + +- **Framework**: Next.js 15 with App Router +- **Language**: TypeScript +- **Styling**: Tailwind CSS + Radix UI components +- **State Management**: React hooks and context +- **Real-time**: WebSocket integration +- **HTTP Client**: Axios with typed service layer +- **Theme**: Dark/light mode support + +### Integration Model + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” HTTP/WebSocket β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Bifrost UI β”‚ ◄─────────────────► β”‚ Bifrost HTTP β”‚ +β”‚ (Next.js) β”‚ β”‚ Transport (Go) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ + β”‚ Build artifacts β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +- **Development**: UI runs on port 3000, connects to Go backend on port 8080 +- **Production**: UI built as static assets served directly by Go HTTP transport +- **Communication**: REST API + WebSocket for real-time features + +## πŸ“± Features Deep Dive + +### Real-time Log Monitoring + +The main dashboard provides comprehensive request monitoring: + +- **Live Updates**: WebSocket connection for real-time log streaming +- **Advanced Filtering**: Filter by providers, models, status, content, and time ranges +- **Request Analytics**: Success rates, average latency, total tokens usage +- **Detailed Views**: Full request/response inspection with syntax highlighting +- **Search**: Full-text search across request content and metadata + +### Provider Configuration + +Manage all your AI providers from a unified interface: + +- **Supported Providers**: OpenAI, Azure OpenAI, Anthropic, AWS Bedrock, Cohere, Google Vertex AI, Mistral, Ollama, Groq, Parasail, SGLang, Cerebras, Gemini, OpenRouter +- **Key Management**: Multiple API keys with weights and model assignments +- **Network Configuration**: Custom base URLs, timeouts, retry policies, proxy settings +- **Provider-specific Settings**: Azure deployments, Bedrock regions, Vertex projects +- **Concurrency Control**: Per-provider concurrency limits and buffer sizes + +### MCP Client Management + +Model Context Protocol integration for advanced AI capabilities: + +- **Client Configuration**: Add, update, and delete MCP clients +- **Connection Monitoring**: Real-time status and health checks +- **Reconnection**: Manual and automatic reconnection capabilities +- **Tool Integration**: Seamless integration with MCP tools and resources + +### Plugin Ecosystem + +Extend Bifrost with powerful plugins: + +- **Maxim Logger**: Advanced LLM observability and analytics +- **Response Mocker**: Mock responses for testing and development +- **Circuit Breaker**: Resilience patterns and failure handling +- **Custom Plugins**: Build your own with the plugin development guide + +## πŸ› οΈ Development + +### Project Structure + +``` +ui/ +β”œβ”€β”€ app/ # Next.js App Router pages +β”‚ β”œβ”€β”€ page.tsx # Main logs dashboard +β”‚ β”œβ”€β”€ config/ # Provider & MCP configuration +β”‚ β”œβ”€β”€ docs/ # Documentation browser +β”‚ └── plugins/ # Plugin management +β”œβ”€β”€ components/ # Reusable UI components +β”‚ β”œβ”€β”€ logs/ # Log monitoring components +β”‚ β”œβ”€β”€ config/ # Configuration forms +β”‚ └── ui/ # Base UI components (Radix) +β”œβ”€β”€ hooks/ # Custom React hooks +β”œβ”€β”€ lib/ # Utilities and services +β”‚ β”œβ”€β”€ api.ts # Backend API service +β”‚ β”œβ”€β”€ types/ # TypeScript definitions +β”‚ └── utils/ # Helper functions +└── scripts/ # Build and deployment scripts +``` + +### API Integration + +The UI uses Redux Toolkit + RTK Query for state management and API communication with the Bifrost HTTP transport backend: + +```typescript +// Example API usage with RTK Query +import { + useGetLogsQuery, + useCreateProviderMutation, + getErrorMessage +} from '@/lib/store' + +// Get real-time logs with automatic caching +const { data: logs, error, isLoading } = useGetLogsQuery({ filters, pagination }) + +// Configure provider with optimistic updates +const [createProvider] = useCreateProviderMutation() + +const handleCreate = async () => { + try { + await createProvider({ + provider: 'openai', + keys: [{ value: 'sk-...', models: ['gpt-4'], weight: 1 }], + // ... other config + }).unwrap() + // Success handling + } catch (error) { + console.error(getErrorMessage(error)) + } +} +``` + +### Component Guidelines + +- **Composition**: Use Radix UI primitives for accessibility +- **Styling**: Tailwind CSS with CSS variables for theming +- **Types**: Full TypeScript coverage matching Go backend schemas +- **Error Handling**: Consistent error states and user feedback + +### Adding New Features + +1. **Backend Integration**: Add API endpoints to `lib/api.ts` +2. **Type Definitions**: Update types in `lib/types/` +3. **UI Components**: Build with Radix UI and Tailwind +4. **State Management**: Use React hooks or context as needed +5. **Real-time Updates**: Integrate WebSocket events when applicable + +## πŸ”§ Configuration + +### Provider Setup + +The UI supports comprehensive provider configuration: + +```typescript +interface ProviderConfig { + keys: Key[] // API keys with model assignments + network_config: NetworkConfig // URLs, timeouts, retries + meta_config?: MetaConfig // Provider-specific settings + concurrency_and_buffer_size: { + // Performance tuning + concurrency: number + buffer_size: number + } + proxy_config?: ProxyConfig // Proxy settings +} +``` + +### Real-time Features + +WebSocket connection provides: + +- Live log streaming +- Connection status monitoring +- Automatic reconnection +- Filtered real-time updates + +## πŸ“Š Monitoring & Analytics + +The dashboard provides comprehensive observability: + +- **Request Metrics**: Total requests, success rate, average latency +- **Token Usage**: Input/output tokens, total consumption tracking +- **Provider Performance**: Per-provider success rates and latencies +- **Error Analysis**: Detailed error categorization and troubleshooting +- **Historical Data**: Time-based filtering and trend analysis + +## 🀝 Contributing + +We welcome contributions! See our [Contributing Guide](https://github.com/maximhq/bifrost/tree/main/docs/contributing) for: + +- Code conventions and style guide +- Development setup and workflow +- Adding new providers or features +- Plugin development guidelines + +## πŸ“š Documentation + +- **Quick Start**: [Get started in 30 seconds](https://github.com/maximhq/bifrost/tree/main/docs/quickstart) +- **Configuration**: [Complete setup guide](https://github.com/maximhq/bifrost/tree/main/docs/usage/http-transport/configuration) +- **API Reference**: [HTTP transport endpoints](https://github.com/maximhq/bifrost/tree/main/docs/usage/http-transport) +- **Architecture**: [Design and performance](https://github.com/maximhq/bifrost/tree/main/docs/architecture) + +## πŸ”— Links + +- **Main Repository**: [github.com/maximhq/bifrost](https://github.com/maximhq/bifrost) +- **HTTP Transport**: [../transports/bifrost-http](../transports/bifrost-http) +- **Documentation**: [docs/](../docs/) +- **Website**: [getmaxim.ai](https://getmaxim.ai) + +## πŸ“„ License + +Licensed under the same terms as the main Bifrost project. See [LICENSE](../LICENSE) for details. + +--- + +_Built with β™₯️ by [Maxim AI](https://getmaxim.ai)_ diff --git a/ui/app/config/page.tsx b/ui/app/config/page.tsx new file mode 100644 index 000000000..514966fd9 --- /dev/null +++ b/ui/app/config/page.tsx @@ -0,0 +1,415 @@ +"use client"; + +import PluginsForm from "@/app/config/views/pluginsForm"; +import FullPageLoader from "@/components/fullPageLoader"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Switch } from "@/components/ui/switch"; +import { Textarea } from "@/components/ui/textarea"; +import { getErrorMessage, useGetCoreConfigQuery, useGetDroppedRequestsQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CoreConfig } from "@/lib/types/config"; +import { parseArrayFromText } from "@/lib/utils/array"; +import { validateOrigins } from "@/lib/utils/validation"; +import { AlertTriangle } from "lucide-react"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; + +const defaultConfig: CoreConfig = { + drop_excess_requests: false, + initial_pool_size: 1000, + prometheus_labels: [], + enable_logging: true, + enable_governance: true, + enforce_governance_header: false, + allow_direct_keys: false, + allowed_origins: [], + max_request_body_size_mb: 100, +}; + +export default function ConfigPage() { + const [droppedRequests, setDroppedRequests] = useState(0); + // RTK Query hooks + const { data: droppedRequestsData } = useGetDroppedRequestsQuery(); + const { data: bifrostConfig, isLoading } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config; + const [needsRestart, setNeedsRestart] = useState(false); + const [updateCoreConfig] = useUpdateCoreConfigMutation(); + + const [localValues, setLocalValues] = useState<{ + initial_pool_size: string; + prometheus_labels: string; + allowed_origins: string; + max_request_body_size_mb: string; + }>({ + initial_pool_size: "300", + prometheus_labels: "", + allowed_origins: "", + max_request_body_size_mb: "100", + }); + + // Handle dropped requests data from RTK Query + useEffect(() => { + if (droppedRequestsData) { + setDroppedRequests(droppedRequestsData.dropped_requests); + } + }, [droppedRequestsData]); + + // Use refs to store timeout IDs + const poolSizeTimeoutRef = useRef(undefined); + const prometheusLabelsTimeoutRef = useRef(undefined); + const allowedOriginsTimeoutRef = useRef(undefined); + const maxRequestBodySizeMBTimeoutRef = useRef(undefined); + + // Update local values when config is loaded + useEffect(() => { + if (bifrostConfig && config) { + setLocalValues({ + initial_pool_size: config?.initial_pool_size?.toString() || "1000", + prometheus_labels: config?.prometheus_labels?.join(", ") || "", + allowed_origins: config?.allowed_origins?.join(", ") || "", + max_request_body_size_mb: config?.max_request_body_size_mb?.toString() || "100", + }); + } + }, [config, bifrostConfig]); + + const updateConfig = useCallback( + async (field: keyof CoreConfig, value: boolean | number | string[]) => { + try { + await updateCoreConfig({ ...(config ?? defaultConfig), [field]: value }).unwrap(); + toast.success("Core setting updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, + [config, updateCoreConfig], + ); + + const handleConfigChange = async (field: keyof CoreConfig, value: boolean | number | string[]) => { + await updateConfig(field, value); + }; + + const handlePoolSizeChange = useCallback( + (value: string) => { + setLocalValues((prev) => ({ ...prev, initial_pool_size: value })); + + // Clear existing timeout + if (poolSizeTimeoutRef.current) { + clearTimeout(poolSizeTimeoutRef.current); + } + + // Set new timeout + poolSizeTimeoutRef.current = setTimeout(() => { + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + updateConfig("initial_pool_size", numValue); + } + }, 1000); + setNeedsRestart(true); + }, + [updateConfig], + ); + + const handlePrometheusLabelsChange = useCallback( + (value: string) => { + setLocalValues((prev) => ({ ...prev, prometheus_labels: value })); + + // Clear existing timeout + if (prometheusLabelsTimeoutRef.current) { + clearTimeout(prometheusLabelsTimeoutRef.current); + } + + // Set new timeout + prometheusLabelsTimeoutRef.current = setTimeout(() => { + updateConfig("prometheus_labels", parseArrayFromText(value)); + }, 1000); + setNeedsRestart(true); + }, + [updateConfig], + ); + + const handleAllowedOriginsChange = useCallback( + (value: string) => { + setLocalValues((prev) => ({ ...prev, allowed_origins: value })); + + // Clear existing timeout + if (allowedOriginsTimeoutRef.current) { + clearTimeout(allowedOriginsTimeoutRef.current); + } + + // Set new timeout + allowedOriginsTimeoutRef.current = setTimeout(() => { + const origins = parseArrayFromText(value); + const validation = validateOrigins(origins); + + if (validation.isValid || origins.length === 0) { + updateConfig("allowed_origins", origins); + } else { + toast.error(`Invalid origins: ${validation.invalidOrigins.join(", ")}. Origins must be valid URLs like https://example.com`); + } + }, 1000); + setNeedsRestart(true); + }, + [updateConfig], + ); + + const handleMaxRequestBodySizeMBChange = useCallback( + (value: string) => { + setLocalValues((prev) => ({ ...prev, max_request_body_size_mb: value })); + + // Clear existing timeout + if (maxRequestBodySizeMBTimeoutRef.current) { + clearTimeout(maxRequestBodySizeMBTimeoutRef.current); + } + + // Set new timeout + maxRequestBodySizeMBTimeoutRef.current = setTimeout(() => { + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + updateConfig("max_request_body_size_mb", numValue); + } + }, 1000); + setNeedsRestart(true); + }, + [updateConfig], + ); + + // Cleanup timeouts on unmount + useEffect(() => { + return () => { + if (poolSizeTimeoutRef.current) { + clearTimeout(poolSizeTimeoutRef.current); + } + if (prometheusLabelsTimeoutRef.current) { + clearTimeout(prometheusLabelsTimeoutRef.current); + } + if (allowedOriginsTimeoutRef.current) { + clearTimeout(allowedOriginsTimeoutRef.current); + } + if (maxRequestBodySizeMBTimeoutRef.current) { + clearTimeout(maxRequestBodySizeMBTimeoutRef.current); + } + }; + }, []); + + return isLoading ? ( + + ) : ( +

+ {/* Page Header */} +
+ + Core System Settings + Configure core Bifrost settings like request handling, pool sizes, and system behavior. + +
+ {/* Drop Excess Requests */} +
+
+ +

+ If enabled, Bifrost will drop requests that exceed pool capacity.{" "} + {config?.drop_excess_requests && droppedRequests > 0 ? ( + + Have dropped {droppedRequests} requests since last restart. + + ) : ( + <> + )} +

+
+ handleConfigChange("drop_excess_requests", checked)} + /> +
+ + {config?.enable_governance && ( +
+
+ +

+ Enforce the use of a virtual key for all requests. If enabled, requests without the x-bf-vk header will be + rejected. +

+
+ handleConfigChange("enforce_governance_header", checked)} + /> +
+ )} + +
+
+ +

+ Allow API keys to be passed directly in request headers (Authorization or x-api-key). Bifrost will directly + use the key. +

+
+ handleConfigChange("allow_direct_keys", checked)} + /> +
+ + + + + The settings below require a Bifrost service restart to take effect. Current connections will continue with existing settings + until restart. + + + +
+
+
+ +

The initial connection pool size.

+
+ handlePoolSizeChange(e.target.value)} + min="1" + /> +
+ {needsRestart && } +
+ +
+
+
+ +

The initial connection pool size.

+
+ handleMaxRequestBodySizeMBChange(e.target.value)} + min="1" + /> +
+ {needsRestart && } +
+ +
+
+
+ +

+ Enable logging of requests and responses to a SQL database. This can add 40-60mb of overhead to the system memory. + {!bifrostConfig?.is_logs_connected && ( + Requires logs store to be configured and enabled in config.json. + )} +

+
+ { + if (bifrostConfig?.is_logs_connected) { + handleConfigChange("enable_logging", checked); + } + }} + /> +
+ {needsRestart && } +
+ +
+
+
+ +

+ Enable governance on requests. You can configure budgets and rate limits in the Governance tab. +

+
+ handleConfigChange("enable_governance", checked)} + /> +
+ {needsRestart && } +
+ + + +
+
+
+ +

Comma-separated list of custom labels to add to the Prometheus metrics.

+
+