Skip to content

Commit e41d9ec

Browse files
feat: vk provider routing added
1 parent d53f042 commit e41d9ec

File tree

35 files changed

+1299
-593
lines changed

35 files changed

+1299
-593
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Variables
44
HOST ?= localhost
55
PORT ?= 8080
6-
APP_DIR ?=
6+
APP_DIR ?=
77
PROMETHEUS_LABELS ?=
88
LOG_STYLE ?= json
99
LOG_LEVEL ?= info

core/bifrost.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
15831583
}
15841584

15851585
if len(keys) == 0 {
1586-
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey)
1586+
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model)
15871587
}
15881588

15891589
// filter out keys which don't support the model; if a key has no models, it supports all models

core/providers/bedrock.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ type BedrockMistralContent struct {
8888
type BedrockMistralChatMessage struct {
8989
Role schemas.ModelChatMessageRole `json:"role"` // Role of the message sender
9090
Content []BedrockMistralContent `json:"content"` // Array of message content
91-
ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls
91+
ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls
9292
ToolCallID *string `json:"tool_call_id,omitempty"` // Optional tool call ID
9393
}
9494

core/schemas/bifrost.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ const (
118118
// a text completion, a chat completion, an embedding request, a speech request, or a transcription request.
119119
type RequestInput struct {
120120
TextCompletionInput *string `json:"text_completion_input,omitempty"`
121-
ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"`
121+
ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"`
122122
EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"`
123123
SpeechInput *SpeechInput `json:"speech_input,omitempty"`
124124
TranscriptionInput *TranscriptionInput `json:"transcription_input,omitempty"`
@@ -295,12 +295,12 @@ type Fallback struct {
295295
// mapped to the provider's parameters.
296296
type ModelParameters struct {
297297
ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool
298-
Tools []Tool `json:"tools,omitempty"` // Tools to use
298+
Tools []Tool `json:"tools,omitempty"` // Tools to use
299299
Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output
300300
TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling
301301
TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling
302302
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate
303-
StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation
303+
StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation
304304
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens
305305
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens
306306
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls
@@ -318,7 +318,7 @@ type FunctionParameters struct {
318318
Description *string `json:"description,omitempty"` // Description of the parameters
319319
Required []string `json:"required,omitempty"` // Required parameter names
320320
Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties
321-
Enum []string `json:"enum,omitempty"` // Enum values for the parameters
321+
Enum []string `json:"enum,omitempty"` // Enum values for the parameters
322322
}
323323

324324
// Function represents a function that can be called by the model.
@@ -492,7 +492,7 @@ type ToolMessage struct {
492492
type AssistantMessage struct {
493493
Refusal *string `json:"refusal,omitempty"`
494494
Annotations []Annotation `json:"annotations,omitempty"`
495-
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
495+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
496496
Thought *string `json:"thought,omitempty"`
497497
}
498498

@@ -795,7 +795,7 @@ type BifrostResponseExtraFields struct {
795795
Provider ModelProvider `json:"provider"`
796796
Params ModelParameters `json:"model_params"`
797797
Latency *int64 `json:"latency,omitempty"`
798-
ChatHistory []BifrostMessage `json:"chat_history,omitempty"`
798+
ChatHistory []BifrostMessage `json:"chat_history,omitempty"`
799799
BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"`
800800
ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses
801801
RawResponse interface{} `json:"raw_response,omitempty"`

framework/configstore/migrations.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
1919
if err := migrationAddCustomProviderConfigJSONColumn(ctx, db); err != nil {
2020
return err
2121
}
22+
if err := migrationAddVirtualKeyProviderConfigTable(ctx, db); err != nil {
23+
return err
24+
}
2225
return nil
2326
}
2427

@@ -247,3 +250,35 @@ func migrationAddCustomProviderConfigJSONColumn(ctx context.Context, db *gorm.DB
247250
}
248251
return nil
249252
}
253+
254+
func migrationAddVirtualKeyProviderConfigTable(ctx context.Context, db *gorm.DB) error {
255+
m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{
256+
ID: "addvirtualkeyproviderconfig",
257+
Migrate: func(tx *gorm.DB) error {
258+
tx = tx.WithContext(ctx)
259+
migrator := tx.Migrator()
260+
261+
if !migrator.HasTable(&TableVirtualKeyProviderConfig{}) {
262+
if err := migrator.CreateTable(&TableVirtualKeyProviderConfig{}); err != nil {
263+
return err
264+
}
265+
}
266+
267+
return nil
268+
},
269+
Rollback: func(tx *gorm.DB) error {
270+
tx = tx.WithContext(ctx)
271+
migrator := tx.Migrator()
272+
273+
if err := migrator.DropTable(&TableVirtualKeyProviderConfig{}); err != nil {
274+
return err
275+
}
276+
return nil
277+
},
278+
}})
279+
err := m.Migrate()
280+
if err != nil {
281+
return fmt.Errorf("error while running db migration: %s", err.Error())
282+
}
283+
return nil
284+
}

framework/configstore/store.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,17 @@ type ConfigStore interface {
5555
// Governance config CRUD
5656
GetVirtualKeys(ctx context.Context) ([]TableVirtualKey, error)
5757
GetVirtualKey(ctx context.Context, id string) (*TableVirtualKey, error)
58+
GetVirtualKeyByValue(ctx context.Context, value string) (*TableVirtualKey, error)
5859
CreateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error
5960
UpdateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error
6061
DeleteVirtualKey(ctx context.Context, id string) error
6162

63+
// Virtual key provider config CRUD
64+
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]TableVirtualKeyProviderConfig, error)
65+
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
66+
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
67+
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
68+
6269
// Team CRUD
6370
GetTeams(ctx context.Context, customerID string) ([]TableTeam, error)
6471
GetTeam(ctx context.Context, id string) (*TableTeam, error)

framework/configstore/tables.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -616,13 +616,12 @@ type TableTeam struct {
616616

617617
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
618618
type TableVirtualKey struct {
619-
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
620-
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
621-
Description string `gorm:"type:text" json:"description,omitempty"`
622-
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
623-
IsActive bool `gorm:"default:true" json:"is_active"`
624-
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed
625-
AllowedProviders []string `gorm:"type:text;serializer:json" json:"allowed_providers"` // Empty means all providers allowed
619+
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
620+
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
621+
Description string `gorm:"type:text" json:"description,omitempty"`
622+
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
623+
IsActive bool `gorm:"default:true" json:"is_active"`
624+
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed
626625

627626
// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
628627
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
@@ -641,6 +640,15 @@ type TableVirtualKey struct {
641640
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
642641
}
643642

643+
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key.
// Rows are linked to TableVirtualKey through VirtualKeyID (TableVirtualKey.ProviderConfigs
// declares that foreign key with an OnDelete:CASCADE constraint).
644+
type TableVirtualKeyProviderConfig struct {
645+
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // Auto-incremented primary key
646+
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` // ID of the owning virtual key (required)
647+
Provider string `gorm:"type:varchar(50);not null" json:"provider"` // Provider name (required, varchar(50))
648+
Weight float64 `gorm:"default:1.0" json:"weight"` // Defaults to 1.0; presumably the relative routing weight for this provider — confirm against the routing code
649+
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed; stored as JSON in a text column
650+
}
651+
644652
// TableModelPricing represents pricing information for AI models
645653
type TableModelPricing struct {
646654
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
@@ -675,11 +683,14 @@ type TableModelPricing struct {
675683
}
676684

677685
// Table names
678-
func (TableBudget) TableName() string { return "governance_budgets" }
679-
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
680-
func (TableCustomer) TableName() string { return "governance_customers" }
681-
func (TableTeam) TableName() string { return "governance_teams" }
682-
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
686+
func (TableBudget) TableName() string { return "governance_budgets" }
687+
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
688+
func (TableCustomer) TableName() string { return "governance_customers" }
689+
func (TableTeam) TableName() string { return "governance_teams" }
690+
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
691+
// TableName returns the table name for TableVirtualKeyProviderConfig
// (presumably consumed through GORM's table-name override convention — confirm).
func (TableVirtualKeyProviderConfig) TableName() string {
692+
return "governance_virtual_key_provider_configs"
693+
}
683694
func (TableConfig) TableName() string { return "governance_config" }
684695
func (TableModelPricing) TableName() string { return "governance_model_pricing" }
685696

framework/logstore/tables.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ type Log struct {
102102
// Virtual fields for JSON output - these will be populated when needed
103103
InputHistoryParsed []schemas.BifrostMessage `gorm:"-" json:"input_history,omitempty"`
104104
OutputMessageParsed *schemas.BifrostMessage `gorm:"-" json:"output_message,omitempty"`
105-
EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"`
105+
EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"`
106106
ParamsParsed *schemas.ModelParameters `gorm:"-" json:"params,omitempty"`
107-
ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"`
108-
ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"`
107+
ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"`
108+
ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"`
109109
TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"`
110110
ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"`
111111
SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"`

framework/pricing/main.go

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pricing
33
import (
44
"context"
55
"fmt"
6+
"slices"
67
"sync"
78
"time"
89

@@ -26,6 +27,8 @@ type PricingManager struct {
2627
pricingData map[string]configstore.TableModelPricing
2728
mu sync.RWMutex
2829

30+
modelPool map[schemas.ModelProvider][]string
31+
2932
// Background sync worker
3033
syncTicker *time.Ticker
3134
done chan struct{}
@@ -75,9 +78,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
7578
configStore: configStore,
7679
logger: logger,
7780
pricingData: make(map[string]configstore.TableModelPricing),
81+
modelPool: make(map[schemas.ModelProvider][]string),
7882
done: make(chan struct{}),
7983
}
8084

85+
logger.Info("initializing pricing manager...")
86+
8187
if configStore != nil {
8288
// Load initial pricing data
8389
if err := pm.loadPricingFromDatabase(ctx); err != nil {
@@ -88,14 +94,16 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
8894
if err := pm.syncPricing(ctx); err != nil {
8995
return nil, fmt.Errorf("failed to sync pricing data: %w", err)
9096
}
91-
9297
} else {
9398
// Load pricing data from config memory
9499
if err := pm.loadPricingIntoMemory(ctx); err != nil {
95100
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
96101
}
97102
}
98103

104+
// Populate model pool with normalized providers
105+
pm.populateModelPool()
106+
99107
// Start background sync worker
100108
pm.syncCtx, pm.syncCancel = context.WithCancel(ctx)
101109
pm.startSyncWorker(pm.syncCtx)
@@ -333,6 +341,80 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string,
333341
return totalCost
334342
}
335343

344+
// populateModelPool populates the model pool with all available models per provider (thread-safe)
345+
func (pm *PricingManager) populateModelPool() {
346+
// Acquire write lock for the entire rebuild operation
347+
pm.mu.Lock()
348+
defer pm.mu.Unlock()
349+
350+
// Clear existing model pool
351+
pm.modelPool = make(map[schemas.ModelProvider][]string)
352+
353+
// Map to track unique models per provider
354+
providerModels := make(map[schemas.ModelProvider]map[string]bool)
355+
356+
// Iterate through all pricing data to collect models per provider
357+
for _, pricing := range pm.pricingData {
358+
// Normalize provider before adding to model pool
359+
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
360+
361+
// Initialize map for this provider if not exists
362+
if providerModels[normalizedProvider] == nil {
363+
providerModels[normalizedProvider] = make(map[string]bool)
364+
}
365+
366+
// Add model to the provider's model set (using map for deduplication)
367+
providerModels[normalizedProvider][pricing.Model] = true
368+
}
369+
370+
// Convert sets to slices and assign to modelPool
371+
for provider, modelSet := range providerModels {
372+
models := make([]string, 0, len(modelSet))
373+
for model := range modelSet {
374+
models = append(models, model)
375+
}
376+
pm.modelPool[provider] = models
377+
}
378+
379+
// Log the populated model pool for debugging
380+
totalModels := 0
381+
for provider, models := range pm.modelPool {
382+
totalModels += len(models)
383+
pm.logger.Debug("populated %d models for provider %s", len(models), string(provider))
384+
}
385+
pm.logger.Info("populated model pool with %d models across %d providers", totalModels, len(pm.modelPool))
386+
}
387+
388+
// GetModelsForProvider returns all available models for a given provider (thread-safe)
389+
func (pm *PricingManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
390+
pm.mu.RLock()
391+
defer pm.mu.RUnlock()
392+
393+
models, exists := pm.modelPool[provider]
394+
if !exists {
395+
return []string{}
396+
}
397+
398+
// Return a copy to prevent external modification
399+
result := make([]string, len(models))
400+
copy(result, models)
401+
return result
402+
}
403+
404+
// GetProvidersForModel returns all providers for a given model (thread-safe)
405+
func (pm *PricingManager) GetProvidersForModel(model string) []schemas.ModelProvider {
406+
pm.mu.RLock()
407+
defer pm.mu.RUnlock()
408+
409+
providers := make([]schemas.ModelProvider, 0)
410+
for provider, models := range pm.modelPool {
411+
if slices.Contains(models, model) {
412+
providers = append(providers, provider)
413+
}
414+
}
415+
return providers
416+
}
417+
336418
// getPricing returns pricing information for a model (thread-safe)
337419
func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) {
338420
pm.mu.RLock()

framework/pricing/sync.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ func (pm *PricingManager) syncPricing(ctx context.Context) error {
6262
pricingData, err := pm.loadPricingFromURL(ctx)
6363
if err != nil {
6464
// Check if we have existing data in database
65-
pricingRecords, err := pm.configStore.GetModelPrices(ctx)
66-
if err != nil {
67-
return fmt.Errorf("failed to get pricing records: %w", err)
65+
pricingRecords, pricingErr := pm.configStore.GetModelPrices(ctx)
66+
if pricingErr != nil {
67+
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
6868
}
6969
if len(pricingRecords) > 0 {
7070
pm.logger.Error("failed to load pricing data from URL, but existing data found in database: %v", err)

0 commit comments

Comments
 (0)