From 5f648cc45463c11bfc13d061551f62dd56564a47 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Thu, 2 Oct 2025 18:33:03 +0530 Subject: [PATCH] feat: vk provider routing added --- Makefile | 2 +- core/bifrost.go | 2 +- core/providers/bedrock.go | 2 +- core/schemas/bifrost.go | 12 +- framework/configstore/migrations.go | 35 + framework/configstore/store.go | 7 + framework/configstore/tables.go | 35 +- framework/logstore/tables.go | 6 +- framework/pricing/main.go | 84 +- framework/pricing/sync.go | 6 +- plugins/governance/resolver.go | 41 +- plugins/governance/tracker.go | 1 - plugins/maxim/main.go | 2 +- plugins/semanticcache/test_utils.go | 6 +- .../bifrost-http/handlers/completions.go | 4 +- .../bifrost-http/handlers/governance.go | 139 ++- .../bifrost-http/handlers/middlewares.go | 132 +++ transports/bifrost-http/handlers/server.go | 8 +- .../integrations/anthropic/types.go | 4 +- .../bifrost-http/integrations/openai/types.go | 6 +- transports/bifrost-http/lib/config.go | 9 +- .../fragments/apiKeysFormFragment.tsx | 12 +- .../teams-customers/views/customerTable.tsx | 31 +- ui/app/teams-customers/views/teamDialog.tsx | 7 +- .../views/virtualKeyDetailsDialog.tsx | 142 ++- .../virtual-keys/views/virtualKeyDialog.tsx | 965 +++++++++++------- .../virtual-keys/views/virtualKeysTable.tsx | 5 +- ui/components/ui/dialog.tsx | 8 +- ui/components/ui/multiSelect.tsx | 2 +- ui/components/ui/numberAndSelect.tsx | 39 +- ui/components/ui/tagInput.tsx | 30 +- ui/hooks/useDebounce.ts | 35 + ui/lib/constants/config.ts | 37 + ui/lib/constants/logs.ts | 12 +- ui/lib/types/governance.ts | 16 +- 35 files changed, 1295 insertions(+), 589 deletions(-) create mode 100644 ui/hooks/useDebounce.ts diff --git a/Makefile b/Makefile index f7e8b72dc..9aee60b8a 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # Variables HOST ?= localhost PORT ?= 8080 -APP_DIR ?= +APP_DIR ?= PROMETHEUS_LABELS ?= LOG_STYLE ?= json LOG_LEVEL ?= info diff --git a/core/bifrost.go b/core/bifrost.go index 88532a5f4..644e1becb 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -1583,7 +1583,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov } if len(keys) == 0 { - return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey) + return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) } // filter out keys which dont support the model, if the key has no models, it is supported for all models diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 92d9c07b0..93442107b 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -88,7 +88,7 @@ type BedrockMistralContent struct { type BedrockMistralChatMessage struct { Role schemas.ModelChatMessageRole `json:"role"` // Role of the message sender Content []BedrockMistralContent `json:"content"` // Array of message content - ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls + ToolCalls []BedrockAnthropicToolCall `json:"tool_calls,omitempty"` // Optional tool calls ToolCallID *string `json:"tool_call_id,omitempty"` // Optional tool call ID } diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index a64b477df..7c6ea193c 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -118,7 +118,7 @@ const ( // a text completion, a chat completion, an embedding request, a speech request, or a transcription request. type RequestInput struct { TextCompletionInput *string `json:"text_completion_input,omitempty"` - ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"` + ChatCompletionInput []BifrostMessage `json:"chat_completion_input,omitempty"` EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"` SpeechInput *SpeechInput `json:"speech_input,omitempty"` TranscriptionInput *TranscriptionInput `json:"transcription_input,omitempty"` @@ -295,12 +295,12 @@ type Fallback struct { // mapped to the provider's parameters. type ModelParameters struct { ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool - Tools []Tool `json:"tools,omitempty"` // Tools to use + Tools []Tool `json:"tools,omitempty"` // Tools to use Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation + StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls @@ -318,7 +318,7 @@ type FunctionParameters struct { Description *string `json:"description,omitempty"` // Description of the parameters Required []string `json:"required,omitempty"` // Required parameter names Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties - Enum []string `json:"enum,omitempty"` // Enum values for the parameters + Enum []string `json:"enum,omitempty"` // Enum values for the parameters } // Function represents a function that can be called by the model. @@ -492,7 +492,7 @@ type ToolMessage struct { type AssistantMessage struct { Refusal *string `json:"refusal,omitempty"` Annotations []Annotation `json:"annotations,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` Thought *string `json:"thought,omitempty"` } @@ -795,7 +795,7 @@ type BifrostResponseExtraFields struct { Provider ModelProvider `json:"provider"` Params ModelParameters `json:"model_params"` Latency *int64 `json:"latency,omitempty"` - ChatHistory []BifrostMessage `json:"chat_history,omitempty"` + ChatHistory []BifrostMessage `json:"chat_history,omitempty"` BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses RawResponse interface{} `json:"raw_response,omitempty"` diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index a00af140c..06aa9cf9e 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -19,6 +19,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddCustomProviderConfigJSONColumn(ctx, db); err != nil { return err } + if err := migrationAddVirtualKeyProviderConfigTable(ctx, db); err != nil { + return err + } return nil } @@ -247,3 +250,35 @@ func migrationAddCustomProviderConfigJSONColumn(ctx context.Context, db *gorm.DB } return nil } + +func migrationAddVirtualKeyProviderConfigTable(ctx context.Context, db *gorm.DB) error { + m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{ + ID: "addvirtualkeyproviderconfig", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if !migrator.HasTable(&TableVirtualKeyProviderConfig{}) { + if err := migrator.CreateTable(&TableVirtualKeyProviderConfig{}); err != nil { + return err + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if err := migrator.DropTable(&TableVirtualKeyProviderConfig{}); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 5428e832d..17f737aee 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -55,10 +55,17 @@ type ConfigStore interface { // Governance config CRUD GetVirtualKeys(ctx context.Context) ([]TableVirtualKey, error) GetVirtualKey(ctx context.Context, id string) (*TableVirtualKey, error) + GetVirtualKeyByValue(ctx context.Context, value string) (*TableVirtualKey, error) CreateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error UpdateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error DeleteVirtualKey(ctx context.Context, id string) error + // Virtual key provider config CRUD + GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]TableVirtualKeyProviderConfig, error) + CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error + UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error + DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error + // Team CRUD GetTeams(ctx context.Context, customerID string) ([]TableTeam, error) GetTeam(ctx context.Context, id string) (*TableTeam, error) diff --git a/framework/configstore/tables.go b/framework/configstore/tables.go index ddc03045c..1ef2b0af9 100644 --- a/framework/configstore/tables.go +++ b/framework/configstore/tables.go @@ -616,13 +616,12 @@ type TableTeam struct { // TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association type TableVirtualKey struct { - ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` - Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"` - Description string `gorm:"type:text" json:"description,omitempty"` - Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value - IsActive bool `gorm:"default:true" json:"is_active"` - AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed - AllowedProviders []string `gorm:"type:text;serializer:json" json:"allowed_providers"` // Empty means all providers allowed + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"` + Description string `gorm:"type:text" json:"description,omitempty"` + Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value + IsActive bool `gorm:"default:true" json:"is_active"` + ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed // Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both) TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` @@ -641,6 +640,15 @@ type TableVirtualKey struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` } +// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key +type TableVirtualKeyProviderConfig struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` + Provider string `gorm:"type:varchar(50);not null" json:"provider"` + Weight float64 `gorm:"default:1.0" json:"weight"` + AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed +} + // TableModelPricing represents pricing information for AI models type TableModelPricing struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` @@ -675,11 +683,14 @@ type TableModelPricing struct { } // Table names -func (TableBudget) TableName() string { return "governance_budgets" } -func (TableRateLimit) TableName() string { return "governance_rate_limits" } -func (TableCustomer) TableName() string { return "governance_customers" } -func (TableTeam) TableName() string { return "governance_teams" } -func (TableVirtualKey) TableName() string { return "governance_virtual_keys" } +func (TableBudget) TableName() string { return "governance_budgets" } +func (TableRateLimit) TableName() string { return "governance_rate_limits" } +func (TableCustomer) TableName() string { return "governance_customers" } +func (TableTeam) TableName() string { return "governance_teams" } +func (TableVirtualKey) TableName() string { return "governance_virtual_keys" } +func (TableVirtualKeyProviderConfig) TableName() string { + return "governance_virtual_key_provider_configs" +} func (TableConfig) TableName() string { return "governance_config" } func (TableModelPricing) TableName() string { return "governance_model_pricing" } diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index 6257ef46e..4cd833f59 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -102,10 +102,10 @@ type Log struct { // Virtual fields for JSON output - these will be populated when needed InputHistoryParsed []schemas.BifrostMessage `gorm:"-" json:"input_history,omitempty"` OutputMessageParsed *schemas.BifrostMessage `gorm:"-" json:"output_message,omitempty"` - EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"` + EmbeddingOutputParsed []schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"` ParamsParsed *schemas.ModelParameters `gorm:"-" json:"params,omitempty"` - ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"` - ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"` + ToolsParsed []schemas.Tool `gorm:"-" json:"tools,omitempty"` + ToolCallsParsed []schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"` TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"` ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"` SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` diff --git a/framework/pricing/main.go b/framework/pricing/main.go index 9c880e6db..7807dc282 100644 --- a/framework/pricing/main.go +++ b/framework/pricing/main.go @@ -3,6 +3,7 @@ package pricing import ( "context" "fmt" + "slices" "sync" "time" @@ -26,6 +27,8 @@ type PricingManager struct { pricingData map[string]configstore.TableModelPricing mu sync.RWMutex + modelPool map[schemas.ModelProvider][]string + // Background sync worker syncTicker *time.Ticker done chan struct{} @@ -75,9 +78,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem configStore: configStore, logger: logger, pricingData: make(map[string]configstore.TableModelPricing), + modelPool: make(map[schemas.ModelProvider][]string), done: make(chan struct{}), } + logger.Info("initializing pricing manager...") + if configStore != nil { // Load initial pricing data if err := pm.loadPricingFromDatabase(ctx); err != nil { @@ -88,7 +94,6 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem if err := pm.syncPricing(ctx); err != nil { return nil, fmt.Errorf("failed to sync pricing data: %w", err) } - } else { // Load pricing data from config memory if err := pm.loadPricingIntoMemory(ctx); err != nil { @@ -96,6 +101,9 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem } } + // Populate model pool with normalized providers + pm.populateModelPool() + // Start background sync worker pm.syncCtx, pm.syncCancel = context.WithCancel(ctx) pm.startSyncWorker(pm.syncCtx) @@ -333,6 +341,80 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string, return totalCost } +// populateModelPool populates the model pool with all available models per provider (thread-safe) +func (pm *PricingManager) populateModelPool() { + // Acquire write lock for the entire rebuild operation + pm.mu.Lock() + defer pm.mu.Unlock() + + // Clear existing model pool + pm.modelPool = make(map[schemas.ModelProvider][]string) + + // Map to track unique models per provider + providerModels := make(map[schemas.ModelProvider]map[string]bool) + + // Iterate through all pricing data to collect models per provider + for _, pricing := range pm.pricingData { + // Normalize provider before adding to model pool + normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) + + // Initialize map for this provider if not exists + if providerModels[normalizedProvider] == nil { + providerModels[normalizedProvider] = make(map[string]bool) + } + + // Add model to the provider's model set (using map for deduplication) + providerModels[normalizedProvider][pricing.Model] = true + } + + // Convert sets to slices and assign to modelPool + for provider, modelSet := range providerModels { + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + pm.modelPool[provider] = models + } + + // Log the populated model pool for debugging + totalModels := 0 + for provider, models := range pm.modelPool { + totalModels += len(models) + pm.logger.Debug("populated %d models for provider %s", len(models), string(provider)) + } + pm.logger.Info("populated model pool with %d models across %d providers", totalModels, len(pm.modelPool)) +} + +// GetModelsForProvider returns all available models for a given provider (thread-safe) +func (pm *PricingManager) GetModelsForProvider(provider schemas.ModelProvider) []string { + pm.mu.RLock() + defer pm.mu.RUnlock() + + models, exists := pm.modelPool[provider] + if !exists { + return []string{} + } + + // Return a copy to prevent external modification + result := make([]string, len(models)) + copy(result, models) + return result +} + +// GetProvidersForModel returns all providers for a given model (thread-safe) +func (pm *PricingManager) GetProvidersForModel(model string) []schemas.ModelProvider { + pm.mu.RLock() + defer pm.mu.RUnlock() + + providers := make([]schemas.ModelProvider, 0) + for provider, models := range pm.modelPool { + if slices.Contains(models, model) { + providers = append(providers, provider) + } + } + return providers +} + // getPricing returns pricing information for a model (thread-safe) func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) { pm.mu.RLock() diff --git a/framework/pricing/sync.go b/framework/pricing/sync.go index 1fa7982d8..97a282be5 100644 --- a/framework/pricing/sync.go +++ b/framework/pricing/sync.go @@ -62,9 +62,9 @@ func (pm *PricingManager) syncPricing(ctx context.Context) error { pricingData, err := pm.loadPricingFromURL(ctx) if err != nil { // Check if we have existing data in database - pricingRecords, err := pm.configStore.GetModelPrices(ctx) - if err != nil { - return fmt.Errorf("failed to get pricing records: %w", err) + pricingRecords, pricingErr := pm.configStore.GetModelPrices(ctx) + if pricingErr != nil { + return fmt.Errorf("failed to get pricing records: %w", pricingErr) } if len(pricingRecords) > 0 { pm.logger.Error("failed to load pricing data from URL, but existing data found in database: %v", err) diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index cfc741b73..28614b103 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -93,20 +93,20 @@ func (r *BudgetResolver) EvaluateRequest(ctx *context.Context, evaluationRequest } } - // 2. Check model filtering - if !r.isModelAllowed(vk, evaluationRequest.Model) { + // 2. Check provider filtering + if !r.isProviderAllowed(vk, evaluationRequest.Provider) { return &EvaluationResult{ - Decision: DecisionModelBlocked, - Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", evaluationRequest.Model), + Decision: DecisionProviderBlocked, + Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", evaluationRequest.Provider), VirtualKey: vk, } } - // 3. Check provider filtering - if !r.isProviderAllowed(vk, evaluationRequest.Provider) { + // 3. Check model filtering + if !r.isModelAllowed(vk, evaluationRequest.Provider, evaluationRequest.Model) { return &EvaluationResult{ - Decision: DecisionProviderBlocked, - Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", evaluationRequest.Provider), + Decision: DecisionModelBlocked, + Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", evaluationRequest.Model), VirtualKey: vk, } } @@ -141,23 +141,38 @@ func (r *BudgetResolver) EvaluateRequest(ctx *context.Context, evaluationRequest } // isModelAllowed checks if the requested model is allowed for this VK -func (r *BudgetResolver) isModelAllowed(vk *configstore.TableVirtualKey, model string) bool { +func (r *BudgetResolver) isModelAllowed(vk *configstore.TableVirtualKey, provider schemas.ModelProvider, model string) bool { // Empty AllowedModels means all models are allowed - if len(vk.AllowedModels) == 0 { + if len(vk.ProviderConfigs) == 0 { return true } - return slices.Contains(vk.AllowedModels, model) + for _, pc := range vk.ProviderConfigs { + if pc.Provider == string(provider) { + if len(pc.AllowedModels) == 0 { + return true + } + return slices.Contains(pc.AllowedModels, model) + } + } + + return false } // isProviderAllowed checks if the requested provider is allowed for this VK func (r *BudgetResolver) isProviderAllowed(vk *configstore.TableVirtualKey, provider schemas.ModelProvider) bool { // Empty AllowedProviders means all providers are allowed - if len(vk.AllowedProviders) == 0 { + if len(vk.ProviderConfigs) == 0 { return true } - return slices.Contains(vk.AllowedProviders, string(provider)) + for _, pc := range vk.ProviderConfigs { + if pc.Provider == string(provider) { + return true + } + } + + return false } // checkRateLimits checks the VK's rate limits using flexible approach diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go index f6d935d6e..da007b894 100644 --- a/plugins/governance/tracker.go +++ b/plugins/governance/tracker.go @@ -223,7 +223,6 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error { } } } - t.logger.Info("startup reset summary: VKs with RL=%d, without RL=%d, RL resets=%d", vksWithRateLimits, vksWithoutRateLimits, len(resetRateLimits)) if len(errs) > 0 { t.logger.Error("startup reset encountered %d errors: %v", len(errs), errs) return fmt.Errorf("startup reset completed with %d errors", len(errs)) diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 94de2a163..bd8567f59 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -392,7 +392,7 @@ func (plugin *Plugin) PostHook(ctxRef *context.Context, res *schemas.BifrostResp } generationID, ok := ctx.Value(GenerationIDKey).(string) if ok { - if bifrostErr != nil { + if bifrostErr != nil { genErr := logging.GenerationError{ Message: bifrostErr.Error.Message, Code: bifrostErr.Error.Code, diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index a965463f2..c71612e3b 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -122,9 +122,9 @@ func NewTestSetup(t *testing.T) *TestSetup { } return NewTestSetupWithConfig(t, &Config{ - Provider: schemas.OpenAI, - EmbeddingModel: "text-embedding-3-small", - Threshold: 0.8, + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.8, CleanUpOnShutdown: true, Keys: []schemas.Key{ { diff --git a/transports/bifrost-http/handlers/completions.go b/transports/bifrost-http/handlers/completions.go index a855c49b5..4852a3cfb 100644 --- a/transports/bifrost-http/handlers/completions.go +++ b/transports/bifrost-http/handlers/completions.go @@ -81,12 +81,12 @@ type CompletionRequest struct { StreamFormat *string `json:"stream_format,omitempty"` ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool - Tools []schemas.Tool `json:"tools,omitempty"` // Tools to use + Tools []schemas.Tool `json:"tools,omitempty"` // Tools to use Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation + StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index ae6e0f0eb..12a084fa9 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -40,29 +40,38 @@ func NewGovernanceHandler(plugin *governance.GovernancePlugin, configStore confi // CreateVirtualKeyRequest represents the request body for creating a virtual key type CreateVirtualKeyRequest struct { - Name string `json:"name" validate:"required"` - Description string `json:"description,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed - AllowedProviders []string `json:"allowed_providers,omitempty"` // Empty means all providers allowed - TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID - CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID - Budget *CreateBudgetRequest `json:"budget,omitempty"` - RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey - IsActive *bool `json:"is_active,omitempty"` + Name string `json:"name" validate:"required"` + Description string `json:"description,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed + ProviderConfigs []struct { + Provider string `json:"provider" validate:"required"` + Weight float64 `json:"weight,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed + } `json:"provider_configs,omitempty"` // Empty means all providers allowed + TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID + CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID + Budget *CreateBudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey + IsActive *bool `json:"is_active,omitempty"` } // UpdateVirtualKeyRequest represents the request body for updating a virtual key type UpdateVirtualKeyRequest struct { - Description *string `json:"description,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` - AllowedProviders []string `json:"allowed_providers,omitempty"` - TeamID *string `json:"team_id,omitempty"` - CustomerID *string `json:"customer_id,omitempty"` - Budget *UpdateBudgetRequest `json:"budget,omitempty"` - RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey - IsActive *bool `json:"is_active,omitempty"` + Description *string `json:"description,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` + ProviderConfigs []struct { + ID *uint `json:"id,omitempty"` // null for new entries + Provider string `json:"provider" validate:"required"` + Weight float64 `json:"weight,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed + } `json:"provider_configs,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` + RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` + KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this VirtualKey + IsActive *bool `json:"is_active,omitempty"` } // CreateBudgetRequest represents the request body for creating a budget @@ -216,16 +225,14 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } vk = configstore.TableVirtualKey{ - ID: uuid.NewString(), - Name: req.Name, - Value: uuid.NewString(), - Description: req.Description, - AllowedModels: req.AllowedModels, - AllowedProviders: req.AllowedProviders, - TeamID: req.TeamID, - CustomerID: req.CustomerID, - IsActive: isActive, - Keys: keys, // Set the keys for the many-to-many relationship + ID: uuid.NewString(), + Name: req.Name, + Value: uuid.NewString(), + Description: req.Description, + TeamID: req.TeamID, + CustomerID: req.CustomerID, + IsActive: isActive, + Keys: keys, // Set the keys for the many-to-many relationship } if req.Budget != nil { @@ -262,6 +269,19 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { return err } + if req.ProviderConfigs != nil { + for _, pc := range req.ProviderConfigs { + if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, &configstore.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + }, tx); err != nil { + return err + } + } + } + return nil }); err != nil { SendError(ctx, 500, err.Error(), h.logger) @@ -340,12 +360,6 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if req.Description != nil { vk.Description = *req.Description } - if req.AllowedModels != nil { - vk.AllowedModels = req.AllowedModels - } - if req.AllowedProviders != nil { - vk.AllowedProviders = req.AllowedProviders - } if req.TeamID != nil { vk.TeamID = req.TeamID vk.CustomerID = nil // Clear CustomerID if setting TeamID @@ -454,7 +468,7 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { var keys []configstore.TableKey if len(req.KeyIDs) > 0 { var err error - keys, err = h.configStore.GetKeysByIDs(ctx,req.KeyIDs) + keys, err = h.configStore.GetKeysByIDs(ctx, req.KeyIDs) if err != nil { return fmt.Errorf("failed to get keys by IDs: %w", err) } @@ -471,6 +485,59 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return err } + if req.ProviderConfigs != nil { + // Get existing provider configs for comparison + var existingConfigs []configstore.TableVirtualKeyProviderConfig + if err := tx.Where("virtual_key_id = ?", vk.ID).Find(&existingConfigs).Error; err != nil { + return err + } + + // Create maps for easier lookup + existingConfigsMap := make(map[uint]configstore.TableVirtualKeyProviderConfig) + for _, config := range existingConfigs { + existingConfigsMap[config.ID] = config + } + + requestConfigsMap := make(map[uint]bool) + + // Process new configs: create new ones and update existing ones + for _, pc := range req.ProviderConfigs { + if pc.ID == nil { + // Create new provider config + if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, &configstore.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + }, tx); err != nil { + return err + } + } else { + // Update existing provider config + existing, ok := existingConfigsMap[*pc.ID] + if !ok { + return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) + } + requestConfigsMap[*pc.ID] = true + existing.Provider = pc.Provider + existing.Weight = pc.Weight + existing.AllowedModels = pc.AllowedModels + if err := h.configStore.UpdateVirtualKeyProviderConfig(ctx, &existing, tx); err != nil { + return err + } + } + } + + // Delete provider configs that are not in the request + for id := range existingConfigsMap { + if !requestConfigsMap[id] { + if err := h.configStore.DeleteVirtualKeyProviderConfig(ctx, id, tx); err != nil { + return err + } + } + } + } + return nil }); err != nil { h.logger.Error("failed to update virtual key: %v", err) diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index b353e2a62..166cc86d6 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -1,6 +1,16 @@ package handlers import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "slices" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -37,6 +47,128 @@ func CorsMiddleware(config *lib.Config) BifrostHTTPMiddleware { } } +func VKProviderRoutingMiddleware(config *lib.Config, logger schemas.Logger) BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if !config.LoadedPlugins[governance.PluginName] { + next(ctx) + return + } + var virtualKeyValue string + // Extract x-bf-vk header + ctx.Request.Header.All()(func(key, value []byte) bool { + if strings.ToLower(string(key)) == "x-bf-vk" { + virtualKeyValue = string(value) + } + return true + }) + // If no virtual key, continue to next handler + if virtualKeyValue == "" { + next(ctx) + return + } + // Only process POST requests with a body + if string(ctx.Method()) != "POST" { + next(ctx) + return + } + // Get the request body + body := ctx.Request.Body() + if len(body) == 0 { + next(ctx) + return + } + // Parse the request body to extract the model field + var requestBody map[string]interface{} + if err := json.Unmarshal(body, &requestBody); err != nil { + // If we can't parse as JSON, continue without modification + next(ctx) + return + } + // Check if the request has a model field + modelValue, hasModel := requestBody["model"] + if !hasModel { + next(ctx) + return + } + modelStr, ok := modelValue.(string) + if !ok || modelStr == "" { + next(ctx) + return + } + // Check if model already has provider prefix (contains "/") + if strings.Contains(modelStr, "/") { + next(ctx) + return + } + opCtx := context.Background() + virtualKey, err := config.ConfigStore.GetVirtualKeyByValue(opCtx, virtualKeyValue) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get virtual key: %v", err), logger) + return + } + if virtualKey == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid virtual key", logger) + return + } + if !virtualKey.IsActive { + next(ctx) + return + } + // Get provider configs for this virtual key + providerConfigs, err := config.ConfigStore.GetVirtualKeyProviderConfigs(opCtx, virtualKey.ID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get virtual key provider configs: %v", err), logger) + return + } + if len(providerConfigs) == 0 { + // No provider configs, continue without modification + next(ctx) + return + } + allowedProviderConfigs := make([]configstore.TableVirtualKeyProviderConfig, 0) + for _, config := range providerConfigs { + if len(config.AllowedModels) == 0 || slices.Contains(config.AllowedModels, modelStr) { + allowedProviderConfigs = append(allowedProviderConfigs, config) + } + } + if len(allowedProviderConfigs) == 0 { + // No allowed provider configs, continue without modification + next(ctx) + return + } + // Weighted random selection from allowed providers + totalWeight := 0.0 + for _, config := range allowedProviderConfigs { + totalWeight += config.Weight + } + // Generate random number between 0 and totalWeight + randomValue := rand.Float64() * totalWeight + // Select provider based on weighted random selection + var selectedProvider schemas.ModelProvider + currentWeight := 0.0 + for _, config := range allowedProviderConfigs { + currentWeight += config.Weight + if randomValue <= currentWeight { + selectedProvider = schemas.ModelProvider(config.Provider) + break + } + } + // Update the model field in the request body + requestBody["model"] = string(selectedProvider) + "/" + modelStr + // Marshal the updated request body back to JSON + updatedBody, err := json.Marshal(requestBody) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to marshal updated request body: %v", err), logger) + return + } + // Replace the request body with the updated one + ctx.Request.SetBody(updatedBody) + next(ctx) + } + } +} + // ChainMiddlewares chains multiple middlewares together // Middlewares are applied in order: the first middleware wraps the second, etc. // This allows earlier middlewares to short-circuit by not calling next(ctx) diff --git a/transports/bifrost-http/handlers/server.go b/transports/bifrost-http/handlers/server.go index 29481563c..78d409035 100644 --- a/transports/bifrost-http/handlers/server.go +++ b/transports/bifrost-http/handlers/server.go @@ -237,12 +237,14 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, pluginConfig func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, error) { var err error plugins := []schemas.Plugin{} + config.LoadedPlugins = make(map[string]bool) // Initialize telemetry plugin promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, config) if err != nil { logger.Error("failed to initialize telemetry plugin: %v", err) } else { plugins = append(plugins, promPlugin) + config.LoadedPlugins[telemetry.PluginName] = true } // Initializing logger plugin var loggingPlugin *logging.LoggerPlugin @@ -253,6 +255,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to initialize logging plugin: %v", err) } else { plugins = append(plugins, loggingPlugin) + config.LoadedPlugins[logging.PluginName] = true } } // Initializing governance plugin @@ -266,7 +269,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to initialize governance plugin: %s", err.Error()) } else { plugins = append(plugins, governancePlugin) - + config.LoadedPlugins[governance.PluginName] = true } } // Currently we support first party plugins only @@ -280,6 +283,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to load plugin %s: %v", plugin.Name, err) } else { plugins = append(plugins, pluginInstance) + config.LoadedPlugins[plugin.Name] = true } } return plugins, nil @@ -466,7 +470,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } // Create fasthttp server instance s.Server = &fasthttp.Server{ - Handler: CorsMiddleware(s.Config)(s.Router.Handler), + Handler: CorsMiddleware(s.Config)(VKProviderRoutingMiddleware(s.Config, logger)(s.Router.Handler)), MaxRequestBodySize: s.Config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024, } return nil diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go index f454be7f3..f9d6c481b 100644 --- a/transports/bifrost-http/integrations/anthropic/types.go +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -69,9 +69,9 @@ type AnthropicMessageRequest struct { Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` TopK *int `json:"top_k,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` Stream *bool `json:"stream,omitempty"` - Tools []AnthropicTool `json:"tools,omitempty"` + Tools []AnthropicTool `json:"tools,omitempty"` ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` } diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go index 42e4fea6a..03c5cb86b 100644 --- a/transports/bifrost-http/integrations/openai/types.go +++ b/transports/bifrost-http/integrations/openai/types.go @@ -19,7 +19,7 @@ type OpenAIChatRequest struct { FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` LogitBias map[string]float64 `json:"logit_bias,omitempty"` User *string `json:"user,omitempty"` - Tools []schemas.Tool `json:"tools,omitempty"` // Reuse schema type + Tools []schemas.Tool `json:"tools,omitempty"` // Reuse schema type ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` Stream *bool `json:"stream,omitempty"` LogProbs *bool `json:"logprobs,omitempty"` @@ -140,8 +140,8 @@ type OpenAIStreamChoice struct { // OpenAIStreamDelta represents the incremental content in a streaming chunk type OpenAIStreamDelta struct { - Role *string `json:"role,omitempty"` - Content *string `json:"content,omitempty"` + Role *string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` ToolCalls []schemas.ToolCall `json:"tool_calls,omitempty"` } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 52de5a406..bc8030f0d 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -134,7 +134,8 @@ type Config struct { EnvKeys map[string][]configstore.EnvKeyInfo // Plugin configs - Plugins []*schemas.PluginConfig + Plugins []*schemas.PluginConfig + LoadedPlugins map[string]bool // Pricing manager PricingManager *pricing.PricingManager @@ -223,7 +224,6 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to get logs store config: %w", err) } - logger.Debug("log store config from DB: %v", logStoreConfig) if logStoreConfig == nil { logStoreConfig = &logstore.Config{ Enabled: true, @@ -238,6 +238,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to initialize logs store: %v", err) } + logger.Info("logs store initialized.") err = config.ConfigStore.UpdateLogsStoreConfig(ctx, logStoreConfig) if err != nil { return nil, fmt.Errorf("failed to update logs store config: %w", err) @@ -602,7 +603,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { if config.ConfigStore != nil { logger.Debug("updating governance config in store") - if err := config.ConfigStore.ExecuteTransaction(ctx,func(tx *gorm.DB) error { + if err := config.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { // Create budgets for _, budget := range config.GovernanceConfig.Budgets { if err := config.ConfigStore.CreateBudget(ctx, &budget, tx); err != nil { @@ -1472,7 +1473,7 @@ func (s *Config) EditMCPClientTools(ctx context.Context, name string, toolsToAdd if err := s.ConfigStore.UpdateMCPConfig(ctx, s.MCPConfig, s.EnvKeys); err != nil { return fmt.Errorf("failed to update MCP config in store: %w", err) } - if err := s.ConfigStore.UpdateEnvKeys(ctx, s.EnvKeys); err != nil { + if err := s.ConfigStore.UpdateEnvKeys(ctx, s.EnvKeys); err != nil { logger.Warn("failed to update env keys: %v", err) } } diff --git a/ui/app/providers/fragments/apiKeysFormFragment.tsx b/ui/app/providers/fragments/apiKeysFormFragment.tsx index 1ebda0058..e33f1fbaf 100644 --- a/ui/app/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/providers/fragments/apiKeysFormFragment.tsx @@ -7,6 +7,7 @@ import { Separator } from "@/components/ui/separator"; import { TagInput } from "@/components/ui/tagInput"; import { Textarea } from "@/components/ui/textarea"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import { MODEL_PLACEHOLDERS } from "@/lib/constants/config"; import { isRedacted } from "@/lib/utils/validation"; import { Info } from "lucide-react"; import { Control, UseFormReturn } from "react-hook-form"; @@ -17,16 +18,7 @@ interface Props { form: UseFormReturn; } -// Model placeholders based on provider type -const MODEL_PLACEHOLDERS = { - default: "e.g. gpt-4, gpt-3.5-turbo. Leave blank for all models.", - openai: "e.g. gpt-4, gpt-3.5-turbo, gpt-4-turbo, gpt-4o", - azure: "e.g. gpt-4, gpt-3.5-turbo (must match deployment mappings)", - bedrock: "e.g. claude-v2, titan-text-express-v1", - vertex: "e.g. gemini-pro, text-bison, chat-bison", -}; - -export function ApiKeyFormFragment({ control, providerName, form }: Props) { +export function ApiKeyFormFragment({ control, providerName }: Props) { const isBedrock = providerName === "bedrock"; const isVertex = providerName === "vertex"; const isAzure = providerName === "azure"; diff --git a/ui/app/teams-customers/views/customerTable.tsx b/ui/app/teams-customers/views/customerTable.tsx index b717c3fa5..b3db75db9 100644 --- a/ui/app/teams-customers/views/customerTable.tsx +++ b/ui/app/teams-customers/views/customerTable.tsx @@ -18,10 +18,11 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp import { getErrorMessage, useDeleteCustomerMutation } from "@/lib/store"; import { Customer, Team, VirtualKey } from "@/lib/types/governance"; import { formatCurrency, parseResetPeriod } from "@/lib/utils/governance"; -import { DollarSign, Edit, Key, Plus, Trash2, Users } from "lucide-react"; +import { Edit, Key, Plus, Trash2, Users } from "lucide-react"; import { useState } from "react"; import { toast } from "sonner"; import CustomerDialog from "./customerDialog"; +import { cn } from "@/lib/utils"; interface CustomersTableProps { customers: Customer[]; @@ -94,6 +95,7 @@ export default function CustomersTable({ customers, teams, virtualKeys, onRefres Name Teams Budget + Reset Period Virtual Keys Actions @@ -136,18 +138,21 @@ export default function CustomersTable({ customers, teams, virtualKeys, onRefres {customer.budget ? ( -
- - - {formatCurrency(customer.budget.current_usage)} / {formatCurrency(customer.budget.max_limit)} - - = customer.budget.max_limit ? "destructive" : "secondary"} - className="text-xs" - > - {parseResetPeriod(customer.budget.reset_duration)} - -
+ = customer.budget.max_limit && "text-destructive", + )} + > + {formatCurrency(customer.budget.current_usage)} / {formatCurrency(customer.budget.max_limit)} + + ) : ( + - + )} +
+ + {customer.budget ? ( + parseResetPeriod(customer.budget.reset_duration) ) : ( - )} diff --git a/ui/app/teams-customers/views/teamDialog.tsx b/ui/app/teams-customers/views/teamDialog.tsx index 4cb15de7f..2ee04143a 100644 --- a/ui/app/teams-customers/views/teamDialog.tsx +++ b/ui/app/teams-customers/views/teamDialog.tsx @@ -14,7 +14,7 @@ import { formatCurrency } from "@/lib/utils/governance"; import { Validator } from "@/lib/utils/validation"; import { formatDistanceToNow } from "date-fns"; import isEqual from "lodash.isequal"; -import { User } from "lucide-react"; +import { Building } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; @@ -185,7 +185,7 @@ export default function TeamDialog({ team, customers, onSave, onCancel }: TeamDi {customers.map((customer) => (
- + {customer.name}
@@ -214,7 +214,8 @@ export default function TeamDialog({ team, customers, onSave, onCancel }: TeamDi Current Usage:
- {formatCurrency(team.budget.current_usage)} / {formatCurrency(team.budget.max_limit)} + {formatCurrency(team.budget.current_usage)} /{" "} + {formatCurrency(team.budget.max_limit)} = team.budget.max_limit ? "destructive" : "default"} diff --git a/ui/app/virtual-keys/views/virtualKeyDetailsDialog.tsx b/ui/app/virtual-keys/views/virtualKeyDetailsDialog.tsx index 22ae15461..0a72d3178 100644 --- a/ui/app/virtual-keys/views/virtualKeyDetailsDialog.tsx +++ b/ui/app/virtual-keys/views/virtualKeyDetailsDialog.tsx @@ -3,6 +3,9 @@ import { Badge } from "@/components/ui/badge"; import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/components/ui/dialog"; import { Separator } from "@/components/ui/separator"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; +import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; +import { ProviderLabels, ProviderName } from "@/lib/constants/logs"; import { VirtualKey } from "@/lib/types/governance"; import { calculateUsagePercentage, formatCurrency, getUsageVariant, parseResetPeriod } from "@/lib/utils/governance"; import { formatDistanceToNow } from "date-fns"; @@ -36,8 +39,8 @@ export default function VirtualKeyDetailDialog({ virtualKey, onClose }: VirtualK return ( - - + + {virtualKey.name} {virtualKey.description || "Virtual key details and usage information"} @@ -86,56 +89,99 @@ export default function VirtualKeyDetailDialog({ virtualKey, onClose }: VirtualK - {/* Model & Provider Restrictions */} + {/* Allowed Keys */}
-

Allowed Models & Providers

+

Allowed Keys

- {!virtualKey.allowed_models && !virtualKey.allowed_providers ? ( - All models and providers allowed + {virtualKey.keys && virtualKey.keys.length > 0 ? ( +
+ + + + Key ID + Allowed Models + + + + {virtualKey.keys.map((key) => ( + + + {key.key_id} + + + {key.models && key.models.length > 0 ? ( +
+ {key.models.map((model: string) => ( + + {model} + + ))} +
+ ) : ( + All models allowed + )} +
+
+ ))} +
+
+
) : ( - <> - {virtualKey.allowed_models && virtualKey.allowed_models.length > 0 ? ( -
- Models -
- {virtualKey.allowed_models && virtualKey.allowed_models.length > 0 ? ( -
- {virtualKey.allowed_models.map((model) => ( - - {model} - - ))} -
- ) : ( - All models allowed - )} -
-
- ) : ( - All models allowed - )} - {virtualKey.allowed_providers && virtualKey.allowed_providers.length > 0 ? ( -
- Providers -
- {virtualKey.allowed_providers && virtualKey.allowed_providers.length > 0 ? ( -
- {virtualKey.allowed_providers.map((provider) => ( - - {provider} - - ))} -
- ) : ( - All providers allowed - )} -
-
- ) : ( - All providers allowed - )} - + No specific keys assigned - all keys allowed + )} +
+
+ + + + {/* Provider Configurations */} +
+

Provider Configurations

+ +
+ {!virtualKey.provider_configs || virtualKey.provider_configs.length === 0 ? ( + All providers allowed with default settings + ) : ( +
+ + + + Provider + Weight + Allowed Models + + + + {virtualKey.provider_configs.map((config, index) => ( + + +
+ + {ProviderLabels[config.provider as ProviderName] || config.provider} +
+
+ + {config.weight} + + + {config.allowed_models && config.allowed_models.length > 0 ? ( +
+ {config.allowed_models.map((model) => ( + + {model} + + ))} +
+ ) : ( + All models allowed + )} +
+
+ ))} +
+
+
)}
diff --git a/ui/app/virtual-keys/views/virtualKeyDialog.tsx b/ui/app/virtual-keys/views/virtualKeyDialog.tsx index 40184e8c5..6aadeed74 100644 --- a/ui/app/virtual-keys/views/virtualKeyDialog.tsx +++ b/ui/app/virtual-keys/views/virtualKeyDialog.tsx @@ -1,25 +1,38 @@ "use client"; -import FormFooter from "@/components/formFooter"; +import { Button } from "@/components/ui/button"; import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/components/ui/dialog"; +import { Form, FormControl, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { MultiSelect } from "@/components/ui/multiSelect"; import NumberAndSelect from "@/components/ui/numberAndSelect"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { DottedSeparator } from "@/components/ui/separator"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import { TagInput } from "@/components/ui/tagInput"; import { Textarea } from "@/components/ui/textarea"; import Toggle from "@/components/ui/toggle"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import { MODEL_PLACEHOLDERS as ModelPlaceholders } from "@/lib/constants/config"; import { resetDurationOptions } from "@/lib/constants/governance"; +import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; +import { ProviderLabels, ProviderName, ProviderNames } from "@/lib/constants/logs"; import { getErrorMessage, useCreateVirtualKeyMutation, useGetAllKeysQuery, useUpdateVirtualKeyMutation } from "@/lib/store"; -import { CreateVirtualKeyRequest, Customer, Team, UpdateVirtualKeyRequest, VirtualKey } from "@/lib/types/governance"; -import { Validator } from "@/lib/utils/validation"; -import isEqual from "lodash.isequal"; -import { Info, User, Users } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; +import { + CreateVirtualKeyRequest, + Customer, + Team, + UpdateVirtualKeyRequest, + VirtualKey, + VirtualKeyProviderConfig, +} from "@/lib/types/governance"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { Building, Info, Trash2, Users } from "lucide-react"; +import { useEffect, useState } from "react"; +import { useForm } from "react-hook-form"; import { toast } from "sonner"; +import { z } from "zod"; interface VirtualKeyDialogProps { virtualKey?: VirtualKey | null; @@ -29,59 +42,57 @@ interface VirtualKeyDialogProps { onCancel: () => void; } -interface VirtualKeyFormData { - name: string; - description: string; - allowedModels: string[]; - allowedProviders: string[]; - entityType: "team" | "customer" | "none"; - teamId: string; - customerId: string; - isActive: boolean; - selectedDBKeys: string[]; // Array of selected DBKey IDs - // Budget - budgetMaxLimit: number | undefined; - budgetResetDuration: string; - // Token limits - tokenMaxLimit: number | undefined; - tokenResetDuration: string; - // Request limits - requestMaxLimit: number | undefined; - requestResetDuration: string; - isDirty: boolean; -} - -// Helper function to create initial state -const createInitialState = (virtualKey?: VirtualKey | null): Omit => { - return { - name: virtualKey?.name || "", - description: virtualKey?.description || "", - allowedModels: virtualKey?.allowed_models || [], - allowedProviders: virtualKey?.allowed_providers || [], - entityType: virtualKey?.team_id ? "team" : virtualKey?.customer_id ? "customer" : "none", - teamId: virtualKey?.team_id || "", - customerId: virtualKey?.customer_id || "", - isActive: virtualKey?.is_active ?? true, - selectedDBKeys: virtualKey?.keys?.map((key) => key.key_id) || [], // Extract key IDs +// Provider configuration schema +const providerConfigSchema = z.object({ + id: z.number().optional(), + provider: z.string().min(1, "Provider is required"), + weight: z.union([z.number().min(0, "Weight must be at least 0").max(1, "Weight must be at most 1"), z.string()]), + allowed_models: z.array(z.string()).optional(), +}); + +// Main form schema +const formSchema = z + .object({ + name: z.string().min(1, "Virtual key name is required"), + description: z.string().optional(), + providerConfigs: z.array(providerConfigSchema).optional(), + entityType: z.enum(["team", "customer", "none"]), + teamId: z.string().optional(), + customerId: z.string().optional(), + isActive: z.boolean(), + selectedDBKeys: z.array(z.string()).optional(), // Budget - budgetMaxLimit: virtualKey?.budget ? virtualKey.budget.max_limit : undefined, // Already in dollars - budgetResetDuration: virtualKey?.budget?.reset_duration || "1M", + budgetMaxLimit: z.string().optional(), + budgetResetDuration: z.string().optional(), // Token limits - tokenMaxLimit: virtualKey?.rate_limit?.token_max_limit || undefined, - tokenResetDuration: virtualKey?.rate_limit?.token_reset_duration || "1h", + tokenMaxLimit: z.string().optional(), + tokenResetDuration: z.string().optional(), // Request limits - requestMaxLimit: virtualKey?.rate_limit?.request_max_limit || undefined, - requestResetDuration: virtualKey?.rate_limit?.request_reset_duration || "1h", - }; -}; + requestMaxLimit: z.string().optional(), + requestResetDuration: z.string().optional(), + }) + .refine( + (data) => { + // Validate that sum of provider weights equals 1 (only when there are multiple providers) + if (data.providerConfigs && data.providerConfigs.length > 1) { + const totalWeight = data.providerConfigs.reduce((sum, config) => { + const weight = typeof config.weight === "string" ? parseFloat(config.weight) : config.weight; + return sum + (isNaN(weight) ? 0 : weight); + }, 0); + return Math.abs(totalWeight - 1) < 0.001; // Allow small floating point errors + } + return true; + }, + { + message: "Sum of all provider weights must equal 1 when multiple providers are configured", + path: ["providerConfigs"], + }, + ); + +type FormData = z.infer; export default function VirtualKeyDialog({ virtualKey, teams, customers, onSave, onCancel }: VirtualKeyDialogProps) { const isEditing = !!virtualKey; - const [initialState] = useState>(createInitialState(virtualKey)); - const [formData, setFormData] = useState({ - ...initialState, - isDirty: false, - }); // RTK Query hooks const { data: keysData, error: keysError, isLoading: keysLoading } = useGetAllKeysQuery(); @@ -91,6 +102,27 @@ export default function VirtualKeyDialog({ virtualKey, teams, customers, onSave, const availableKeys = keysData || []; + // Form setup + const form = useForm({ + resolver: zodResolver(formSchema), + defaultValues: { + name: virtualKey?.name || "", + description: virtualKey?.description || "", + providerConfigs: virtualKey?.provider_configs || [], + entityType: virtualKey?.team_id ? "team" : virtualKey?.customer_id ? "customer" : "none", + teamId: virtualKey?.team_id || "", + customerId: virtualKey?.customer_id || "", + isActive: virtualKey?.is_active ?? true, + selectedDBKeys: virtualKey?.keys?.map((key) => key.key_id) || [], + budgetMaxLimit: virtualKey?.budget ? String(virtualKey.budget.max_limit) : "", + budgetResetDuration: virtualKey?.budget?.reset_duration || "1M", + tokenMaxLimit: virtualKey?.rate_limit?.token_max_limit ? String(virtualKey.rate_limit.token_max_limit) : "", + tokenResetDuration: virtualKey?.rate_limit?.token_reset_duration || "1h", + requestMaxLimit: virtualKey?.rate_limit?.request_max_limit ? String(virtualKey.rate_limit.request_max_limit) : "", + requestResetDuration: virtualKey?.rate_limit?.request_reset_duration || "1h", + }, + }); + // Handle keys loading error useEffect(() => { if (keysError) { @@ -98,139 +130,94 @@ export default function VirtualKeyDialog({ virtualKey, teams, customers, onSave, } }, [keysError]); - // Track isDirty state - useEffect(() => { - const currentData = { - name: formData.name, - description: formData.description, - allowedModels: formData.allowedModels, - allowedProviders: formData.allowedProviders, - entityType: formData.entityType, - teamId: formData.teamId, - customerId: formData.customerId, - isActive: formData.isActive, - selectedDBKeys: formData.selectedDBKeys, - budgetMaxLimit: formData.budgetMaxLimit, - budgetResetDuration: formData.budgetResetDuration, - tokenMaxLimit: formData.tokenMaxLimit, - tokenResetDuration: formData.tokenResetDuration, - requestMaxLimit: formData.requestMaxLimit, - requestResetDuration: formData.requestResetDuration, + // Provider configuration state + const [selectedProvider, setSelectedProvider] = useState(""); + + // Get current provider configs from form + const providerConfigs = form.watch("providerConfigs") || []; + + // Handle adding a new provider configuration + const handleAddProvider = (provider: string) => { + const existingConfig = providerConfigs.find((config) => config.provider === provider); + if (existingConfig) { + toast.error("This provider is already configured"); + return; + } + + const newConfig: VirtualKeyProviderConfig = { + provider: provider, + weight: 0.5, // Default weight, user can adjust + allowed_models: [], }; - setFormData((prev) => ({ - ...prev, - isDirty: !isEqual(initialState, currentData), - })); - }, [ - formData.name, - formData.description, - formData.allowedModels, - formData.allowedProviders, - formData.entityType, - formData.teamId, - formData.customerId, - formData.isActive, - formData.selectedDBKeys, - formData.budgetMaxLimit, - formData.budgetResetDuration, - formData.tokenMaxLimit, - formData.tokenResetDuration, - formData.requestMaxLimit, - formData.requestResetDuration, - initialState, - ]); - - // Validation - const validator = useMemo( - () => - new Validator([ - // Basic validation - Validator.required(formData.name.trim(), "Virtual key name is required"), - - // Check if anything is dirty - Validator.custom(formData.isDirty, "No changes to save"), - - // Entity validation - Validator.custom( - formData.entityType === "none" || - (formData.entityType === "team" && !!formData.teamId) || - (formData.entityType === "customer" && !!formData.customerId), - "Please select a valid team or customer assignment", - ), - - // Budget validation - ...(formData.budgetMaxLimit - ? [ - Validator.minValue(formData.budgetMaxLimit, 0.01, "Budget max limit must be greater than $0.01"), - Validator.required(formData.budgetResetDuration, "Budget reset duration is required"), - ] - : []), - - // Rate limit validation - at least one limit must be set if rate limiting is enabled - ...(formData.tokenMaxLimit || formData.requestMaxLimit - ? [ - // Token limit validation - ...(formData.tokenMaxLimit - ? [ - Validator.required(formData.tokenMaxLimit, "Token max limit is required when token limiting is enabled"), - Validator.minValue(formData.tokenMaxLimit || 0, 1, "Token max limit must be at least 1"), - Validator.required(formData.tokenResetDuration, "Token reset duration is required"), - ] - : []), - // Request limit validation - ...(formData.requestMaxLimit - ? [ - Validator.required(formData.requestMaxLimit, "Request max limit is required when request limiting is enabled"), - Validator.minValue(formData.requestMaxLimit || 0, 1, "Request max limit must be at least 1"), - Validator.required(formData.requestResetDuration, "Request reset duration is required"), - ] - : []), - ] - : []), - ]), - [formData], - ); - const updateField = (field: K, value: VirtualKeyFormData[K]) => { - setFormData((prev) => ({ ...prev, [field]: value })); + form.setValue("providerConfigs", [...providerConfigs, newConfig], { shouldDirty: true }); }; - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); + // Handle removing a provider configuration + const handleRemoveProvider = (index: number) => { + const updatedConfigs = providerConfigs.filter((_, i) => i !== index); + form.setValue("providerConfigs", updatedConfigs, { shouldDirty: true }); + }; - if (!validator.isValid()) { - toast.error(validator.getFirstError()); - return; - } + // Handle updating provider configuration + const handleUpdateProviderConfig = (index: number, field: keyof VirtualKeyProviderConfig, value: any) => { + const updatedConfigs = [...providerConfigs]; + updatedConfigs[index] = { ...updatedConfigs[index], [field]: value }; + form.setValue("providerConfigs", updatedConfigs, { shouldDirty: true }); + }; + // Helper function to convert string weights to numbers + const normalizeProviderConfigs = (configs: (VirtualKeyProviderConfig & { weight: string | number })[]): VirtualKeyProviderConfig[] => { + return configs.map((config) => ({ + ...config, + weight: typeof config.weight === "string" ? parseFloat(config.weight) || 0 : config.weight, + })); + }; + + // Normalize numeric fields to ensure they are numbers or undefined + const normalizeNumericField = (value: string | undefined): number | undefined => { + if (value === undefined || value === "") return undefined; + const num = parseFloat(value); + return isNaN(num) ? undefined : num; + }; + + // Handle form submission + const onSubmit = async (data: FormData) => { try { + // Normalize provider configs to ensure weights are numbers + const normalizedProviderConfigs = data.providerConfigs + ? normalizeProviderConfigs(data.providerConfigs as (VirtualKeyProviderConfig & { weight: string | number })[]) + : []; + if (isEditing && virtualKey) { // Update existing virtual key const updateData: UpdateVirtualKeyRequest = { - description: formData.description || undefined, - allowed_models: formData.allowedModels, - allowed_providers: formData.allowedProviders, - team_id: formData.entityType === "team" ? formData.teamId : undefined, - customer_id: formData.entityType === "customer" ? formData.customerId : undefined, - key_ids: !isEqual(formData.selectedDBKeys, initialState.selectedDBKeys) ? formData.selectedDBKeys : undefined, // Only send if changed - is_active: formData.isActive, + description: data.description || undefined, + provider_configs: normalizedProviderConfigs, + team_id: data.entityType === "team" ? data.teamId : undefined, + customer_id: data.entityType === "customer" ? data.customerId : undefined, + key_ids: data.selectedDBKeys, + is_active: data.isActive, }; // Add budget if enabled - if (formData.budgetMaxLimit) { + const budgetMaxLimit = normalizeNumericField(data.budgetMaxLimit); + if (budgetMaxLimit) { updateData.budget = { - max_limit: formData.budgetMaxLimit, // Already in dollars - reset_duration: formData.budgetResetDuration, + max_limit: budgetMaxLimit, + reset_duration: data.budgetResetDuration || "1M", }; } // Add rate limit if enabled - if (formData.tokenMaxLimit || formData.requestMaxLimit) { + const tokenMaxLimit = normalizeNumericField(data.tokenMaxLimit); + const requestMaxLimit = normalizeNumericField(data.requestMaxLimit); + if (tokenMaxLimit || requestMaxLimit) { updateData.rate_limit = { - token_max_limit: formData.tokenMaxLimit, - token_reset_duration: formData.tokenResetDuration, - request_max_limit: formData.requestMaxLimit, - request_reset_duration: formData.requestResetDuration, + token_max_limit: tokenMaxLimit, + token_reset_duration: data.tokenResetDuration || "1h", + request_max_limit: requestMaxLimit, + request_reset_duration: data.requestResetDuration || "1h", }; } @@ -239,31 +226,33 @@ export default function VirtualKeyDialog({ virtualKey, teams, customers, onSave, } else { // Create new virtual key const createData: CreateVirtualKeyRequest = { - name: formData.name, - description: formData.description || undefined, - allowed_models: formData.allowedModels.length > 0 ? formData.allowedModels : undefined, - allowed_providers: formData.allowedProviders.length > 0 ? formData.allowedProviders : undefined, - team_id: formData.entityType === "team" ? formData.teamId : undefined, - customer_id: formData.entityType === "customer" ? formData.customerId : undefined, - key_ids: formData.selectedDBKeys.length > 0 ? formData.selectedDBKeys : undefined, // Empty means all keys - is_active: formData.isActive, + name: data.name, + description: data.description || undefined, + provider_configs: normalizedProviderConfigs, + team_id: data.entityType === "team" ? data.teamId : undefined, + customer_id: data.entityType === "customer" ? data.customerId : undefined, + key_ids: data.selectedDBKeys, + is_active: data.isActive, }; // Add budget if enabled - if (formData.budgetMaxLimit) { + const budgetMaxLimit = normalizeNumericField(data.budgetMaxLimit); + if (budgetMaxLimit) { createData.budget = { - max_limit: formData.budgetMaxLimit, // Already in dollars - reset_duration: formData.budgetResetDuration, + max_limit: budgetMaxLimit, + reset_duration: data.budgetResetDuration || "1M", }; } // Add rate limit if enabled - if (formData.tokenMaxLimit || formData.requestMaxLimit) { + const tokenMaxLimit = normalizeNumericField(data.tokenMaxLimit); + const requestMaxLimit = normalizeNumericField(data.requestMaxLimit); + if (tokenMaxLimit || requestMaxLimit) { createData.rate_limit = { - token_max_limit: formData.tokenMaxLimit, - token_reset_duration: formData.tokenResetDuration, - request_max_limit: formData.requestMaxLimit, - request_reset_duration: formData.requestResetDuration, + token_max_limit: tokenMaxLimit, + token_reset_duration: data.tokenResetDuration || "1h", + request_max_limit: requestMaxLimit, + request_reset_duration: data.requestResetDuration || "1h", }; } @@ -289,217 +278,439 @@ export default function VirtualKeyDialog({ virtualKey, teams, customers, onSave,
-
-
-
- - updateField("name", e.target.value)} - maxLength={50} - disabled={isEditing} // Can't change name when editing - /> -
+ + +
+ {/* Basic Information */} +
+ ( + + Name * + + + + + + )} + /> + + ( + + Description + +