Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Variables
HOST ?= localhost
PORT ?= 8080
APP_DIR ?=
APP_DIR ?=
PROMETHEUS_LABELS ?=
LOG_STYLE ?= json
LOG_LEVEL ?= info
Expand Down
2 changes: 1 addition & 1 deletion core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
}

if len(keys) == 0 {
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey)
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model)
}

// filter out keys which dont support the model, if the key has no models, it is supported for all models
Expand Down
2 changes: 1 addition & 1 deletion core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type BedrockMistralContent struct {
type BedrockMistralChatMessage struct {
Role schemas.ModelChatMessageRole `json:"role"` // Role of the message sender
Content []BedrockMistralContent `json:"content"` // Array of message content
ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls
ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls
ToolCallID *string `json:"tool_call_id,omitempty"` // Optional tool call ID
}

Expand Down
12 changes: 6 additions & 6 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ const (
// a text completion, a chat completion, an embedding request, a speech request, or a transcription request.
type RequestInput struct {
TextCompletionInput *string `json:"text_completion_input,omitempty"`
ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"`
ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"`
EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"`
SpeechInput *SpeechInput `json:"speech_input,omitempty"`
TranscriptionInput *TranscriptionInput `json:"transcription_input,omitempty"`
Expand Down Expand Up @@ -295,12 +295,12 @@ type Fallback struct {
// mapped to the provider's parameters.
type ModelParameters struct {
ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool
Tools []Tool `json:"tools,omitempty"` // Tools to use
Tools []Tool `json:"tools,omitempty"` // Tools to use
Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output
TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling
TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate
StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation
StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls
Expand All @@ -318,7 +318,7 @@ type FunctionParameters struct {
Description *string `json:"description,omitempty"` // Description of the parameters
Required []string `json:"required,omitempty"` // Required parameter names
Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties
Enum []string `json:"enum,omitempty"` // Enum values for the parameters
Enum []string `json:"enum,omitempty"` // Enum values for the parameters
}

// Function represents a function that can be called by the model.
Expand Down Expand Up @@ -492,7 +492,7 @@ type ToolMessage struct {
type AssistantMessage struct {
Refusal *string `json:"refusal,omitempty"`
Annotations []Annotation `json:"annotations,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Thought *string `json:"thought,omitempty"`
}

Expand Down Expand Up @@ -795,7 +795,7 @@ type BifrostResponseExtraFields struct {
Provider ModelProvider `json:"provider"`
Params ModelParameters `json:"model_params"`
Latency *int64 `json:"latency,omitempty"`
ChatHistory []BifrostMessage `json:"chat_history,omitempty"`
ChatHistory []BifrostMessage `json:"chat_history,omitempty"`
BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"`
ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses
RawResponse interface{} `json:"raw_response,omitempty"`
Expand Down
35 changes: 35 additions & 0 deletions framework/configstore/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
if err := migrationAddCustomProviderConfigJSONColumn(ctx, db); err != nil {
return err
}
if err := migrationAddVirtualKeyProviderConfigTable(ctx, db); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -247,3 +250,35 @@ func migrationAddCustomProviderConfigJSONColumn(ctx context.Context, db *gorm.DB
}
return nil
}

func migrationAddVirtualKeyProviderConfigTable(ctx context.Context, db *gorm.DB) error {
m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{
ID: "addvirtualkeyproviderconfig",
Migrate: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()

if !migrator.HasTable(&TableVirtualKeyProviderConfig{}) {
if err := migrator.CreateTable(&TableVirtualKeyProviderConfig{}); err != nil {
return err
}
}

return nil
},
Rollback: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()

if err := migrator.DropTable(&TableVirtualKeyProviderConfig{}); err != nil {
return err
}
return nil
},
}})
err := m.Migrate()
if err != nil {
return fmt.Errorf("error while running db migration: %s", err.Error())
}
return nil
}
7 changes: 7 additions & 0 deletions framework/configstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,17 @@ type ConfigStore interface {
// Governance config CRUD
GetVirtualKeys(ctx context.Context) ([]TableVirtualKey, error)
GetVirtualKey(ctx context.Context, id string) (*TableVirtualKey, error)
GetVirtualKeyByValue(ctx context.Context, value string) (*TableVirtualKey, error)
CreateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error
UpdateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error
DeleteVirtualKey(ctx context.Context, id string) error

// Virtual key provider config CRUD
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]TableVirtualKeyProviderConfig, error)
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error

// Team CRUD
GetTeams(ctx context.Context, customerID string) ([]TableTeam, error)
GetTeam(ctx context.Context, id string) (*TableTeam, error)
Expand Down
35 changes: 23 additions & 12 deletions framework/configstore/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,13 +616,12 @@ type TableTeam struct {

// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
type TableVirtualKey struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
IsActive bool `gorm:"default:true" json:"is_active"`
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed
AllowedProviders []string `gorm:"type:text;serializer:json" json:"allowed_providers"` // Empty means all providers allowed
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
IsActive bool `gorm:"default:true" json:"is_active"`
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed

// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
Expand All @@ -641,6 +640,15 @@ type TableVirtualKey struct {
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}

// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
type TableVirtualKeyProviderConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
Weight float64 `gorm:"default:1.0" json:"weight"`
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed
}

// TableModelPricing represents pricing information for AI models
type TableModelPricing struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Expand Down Expand Up @@ -675,11 +683,14 @@ type TableModelPricing struct {
}

// Table names
func (TableBudget) TableName() string { return "governance_budgets" }
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
func (TableCustomer) TableName() string { return "governance_customers" }
func (TableTeam) TableName() string { return "governance_teams" }
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
func (TableBudget) TableName() string { return "governance_budgets" }
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
func (TableCustomer) TableName() string { return "governance_customers" }
func (TableTeam) TableName() string { return "governance_teams" }
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
func (TableVirtualKeyProviderConfig) TableName() string {
return "governance_virtual_key_provider_configs"
}
func (TableConfig) TableName() string { return "governance_config" }
func (TableModelPricing) TableName() string { return "governance_model_pricing" }

Expand Down
6 changes: 3 additions & 3 deletions framework/logstore/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ type Log struct {
// Virtual fields for JSON output - these will be populated when needed
InputHistoryParsed []schemas.BifrostMessage `gorm:"-" json:"input_history,omitempty"`
OutputMessageParsed *schemas.BifrostMessage `gorm:"-" json:"output_message,omitempty"`
EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"`
EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"`
ParamsParsed *schemas.ModelParameters `gorm:"-" json:"params,omitempty"`
ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"`
ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"`
ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"`
ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"`
TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"`
ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"`
SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"`
Expand Down
84 changes: 83 additions & 1 deletion framework/pricing/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pricing
import (
"context"
"fmt"
"slices"
"sync"
"time"

Expand All @@ -26,6 +27,8 @@ type PricingManager struct {
pricingData map[string]configstore.TableModelPricing
mu sync.RWMutex

modelPool map[schemas.ModelProvider][]string

// Background sync worker
syncTicker *time.Ticker
done chan struct{}
Expand Down Expand Up @@ -75,9 +78,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
configStore: configStore,
logger: logger,
pricingData: make(map[string]configstore.TableModelPricing),
modelPool: make(map[schemas.ModelProvider][]string),
done: make(chan struct{}),
}

logger.Info("initializing pricing manager...")

if configStore != nil {
// Load initial pricing data
if err := pm.loadPricingFromDatabase(ctx); err != nil {
Expand All @@ -88,14 +94,16 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
if err := pm.syncPricing(ctx); err != nil {
return nil, fmt.Errorf("failed to sync pricing data: %w", err)
}

} else {
// Load pricing data from config memory
if err := pm.loadPricingIntoMemory(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
}
}

// Populate model pool with normalized providers
pm.populateModelPool()

// Start background sync worker
pm.syncCtx, pm.syncCancel = context.WithCancel(ctx)
pm.startSyncWorker(pm.syncCtx)
Expand Down Expand Up @@ -333,6 +341,80 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string,
return totalCost
}

// populateModelPool populates the model pool with all available models per provider (thread-safe)
func (pm *PricingManager) populateModelPool() {
// Acquire write lock for the entire rebuild operation
pm.mu.Lock()
defer pm.mu.Unlock()

// Clear existing model pool
pm.modelPool = make(map[schemas.ModelProvider][]string)

// Map to track unique models per provider
providerModels := make(map[schemas.ModelProvider]map[string]bool)

// Iterate through all pricing data to collect models per provider
for _, pricing := range pm.pricingData {
// Normalize provider before adding to model pool
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))

// Initialize map for this provider if not exists
if providerModels[normalizedProvider] == nil {
providerModels[normalizedProvider] = make(map[string]bool)
}

// Add model to the provider's model set (using map for deduplication)
providerModels[normalizedProvider][pricing.Model] = true
}

// Convert sets to slices and assign to modelPool
for provider, modelSet := range providerModels {
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
pm.modelPool[provider] = models
}

// Log the populated model pool for debugging
totalModels := 0
for provider, models := range pm.modelPool {
totalModels += len(models)
pm.logger.Debug("populated %d models for provider %s", len(models), string(provider))
}
pm.logger.Info("populated model pool with %d models across %d providers", totalModels, len(pm.modelPool))
}

// GetModelsForProvider returns all available models for a given provider (thread-safe)
func (pm *PricingManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
pm.mu.RLock()
defer pm.mu.RUnlock()

models, exists := pm.modelPool[provider]
if !exists {
return []string{}
}

// Return a copy to prevent external modification
result := make([]string, len(models))
copy(result, models)
return result
}

// GetProvidersForModel returns all providers for a given model (thread-safe)
func (pm *PricingManager) GetProvidersForModel(model string) []schemas.ModelProvider {
pm.mu.RLock()
defer pm.mu.RUnlock()

providers := make([]schemas.ModelProvider, 0)
for provider, models := range pm.modelPool {
if slices.Contains(models, model) {
providers = append(providers, provider)
}
}
return providers
}

// getPricing returns pricing information for a model (thread-safe)
func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) {
pm.mu.RLock()
Expand Down
6 changes: 3 additions & 3 deletions framework/pricing/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func (pm *PricingManager) syncPricing(ctx context.Context) error {
pricingData, err := pm.loadPricingFromURL(ctx)
if err != nil {
// Check if we have existing data in database
pricingRecords, err := pm.configStore.GetModelPrices(ctx)
if err != nil {
return fmt.Errorf("failed to get pricing records: %w", err)
pricingRecords, pricingErr := pm.configStore.GetModelPrices(ctx)
if pricingErr != nil {
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
}
if len(pricingRecords) > 0 {
pm.logger.Error("failed to load pricing data from URL, but existing data found in database: %v", err)
Expand Down
Loading