Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"

"github.com/google/uuid"
"github.com/maximhq/bifrost/core/providers"
schemas "github.com/maximhq/bifrost/core/schemas"
)
Expand Down Expand Up @@ -681,10 +682,11 @@ transferComplete:
providerKey,
providerConfig.ConcurrencyAndBufferSize.BufferSize)

waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
currentWaitGroup := waitGroupValue.(*sync.WaitGroup)

for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
waitGroup := waitGroupValue.(*sync.WaitGroup)
waitGroup.Add(1)
currentWaitGroup.Add(1)
go bifrost.requestWorker(provider, providerConfig, newQueue)
}

Expand Down Expand Up @@ -992,10 +994,11 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
return fmt.Errorf("failed to create provider for the given key: %v", err)
}

waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
currentWaitGroup := waitGroupValue.(*sync.WaitGroup)

for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
waitGroup := waitGroupValue.(*sync.WaitGroup)
waitGroup.Add(1)
currentWaitGroup.Add(1)
go bifrost.requestWorker(provider, providerConfig, queue)
}

Expand Down Expand Up @@ -1182,6 +1185,8 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR

// Try fallbacks in order
for _, fallback := range req.Fallbacks {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())

fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
if fallbackReq == nil {
continue
Expand All @@ -1190,7 +1195,7 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
// Try the fallback provider
result, fallbackErr := bifrost.tryRequest(fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
bifrost.logger.Debug(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}

Expand Down Expand Up @@ -1234,6 +1239,8 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi

// Try fallbacks in order
for _, fallback := range req.Fallbacks {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())

fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
if fallbackReq == nil {
continue
Expand All @@ -1242,7 +1249,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi
// Try the fallback provider
result, fallbackErr := bifrost.tryStreamRequest(fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
bifrost.logger.Debug(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}

Expand Down
2 changes: 1 addition & 1 deletion core/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.38.0
github.com/aws/aws-sdk-go-v2/config v1.31.0
github.com/bytedance/sonic v1.14.0
github.com/google/uuid v1.6.0
github.com/mark3labs/mcp-go v0.37.0
github.com/rs/zerolog v1.34.0
github.com/valyala/fasthttp v1.65.0
Expand All @@ -33,7 +34,6 @@ require (
github.com/bytedance/sonic/loader v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
Expand Down
12 changes: 11 additions & 1 deletion core/providers/sgl.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,17 @@ func (provider *SGLProvider) Responses(ctx context.Context, key schemas.Key, req

// Embedding is not supported by the SGL provider.
func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
return nil, newUnsupportedOperationError("embedding", "sgl")
return handleOpenAIEmbeddingRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/embeddings",
request,
key,
provider.networkConfig.ExtraHeaders,
provider.GetProviderKey(),
provider.sendBackRawResponse,
provider.logger,
)
}

// ChatCompletionStream performs a streaming chat completion request to the SGL API.
Expand Down
1 change: 1 addition & 0 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ type BifrostContextKey string
// BifrostContextKeyRequestType is a context key for the request type.
const (
BifrostContextKeyRequestID BifrostContextKey = "request-id"
BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id"
BifrostContextKeyVirtualKeyHeader BifrostContextKey = "x-bf-vk"
BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key"
BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator"
Expand Down
86 changes: 73 additions & 13 deletions core/schemas/providers/gemini/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,19 +386,56 @@ func (r *GenerateContentResponse) ToBifrostResponse() *schemas.BifrostResponse {
},
}
} else if hasText && textContent != "" {
// This is a transcription response
response.Object = "audio.transcription"
response.Transcribe = &schemas.BifrostTranscribe{
Text: textContent,
Usage: &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
},
BifrostTranscribeNonStreamResponse: &schemas.BifrostTranscribeNonStreamResponse{
Task: schemas.Ptr("transcribe"),
},
// Check if this is actually a transcription response by looking for transcription context
// Only treat as transcription if we have explicit transcription metadata or context
isTranscription := r.isTranscriptionResponse()

if isTranscription {
// This is a transcription response
response.Object = "audio.transcription"
response.Transcribe = &schemas.BifrostTranscribe{
Text: textContent,
Usage: &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
},
BifrostTranscribeNonStreamResponse: &schemas.BifrostTranscribeNonStreamResponse{
Task: schemas.Ptr("transcribe"),
},
}
} else {
// This is a regular chat completion response
response.Object = "chat.completion"

// Create choice from the candidate
choice := schemas.BifrostChatResponseChoice{
Index: 0,
BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{
Message: schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: schemas.ChatMessageContent{
ContentStr: &textContent,
},
},
},
}

// Set finish reason if available
if candidate.FinishReason != "" {
finishReason := string(candidate.FinishReason)
choice.FinishReason = &finishReason
}

response.Choices = []schemas.BifrostChatResponseChoice{choice}

// Set usage information
response.Usage = &schemas.LLMUsage{
PromptTokens: inputTokens,
CompletionTokens: outputTokens,
TotalTokens: totalTokens,
}
}
}
}
Expand All @@ -407,6 +444,29 @@ func (r *GenerateContentResponse) ToBifrostResponse() *schemas.BifrostResponse {
return response
}

// isTranscriptionResponse determines if this response is from a transcription request
// by checking for transcription-specific context and metadata
func (r *GenerateContentResponse) isTranscriptionResponse() bool {
// Check if any candidates contain audio input data in their parts
// This would indicate the original request included audio for transcription
for _, candidate := range r.Candidates {
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && part.InlineData.MIMEType != "" {
// If we have audio data in the response parts, it's likely a transcription
if strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
return true
}
}
}
}
}

// Default to false - assume it's a regular chat completion
// This is safer than incorrectly classifying chat responses as transcriptions
return false
}

// FromBifrostResponse converts a BifrostResponse back to Gemini's GenerateContentResponse
func ToGeminiGenerationResponse(bifrostResp *schemas.BifrostResponse) interface{} {
if bifrostResp == nil {
Expand Down
50 changes: 50 additions & 0 deletions framework/configstore/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
if err := migrationAddOpenAIUseResponsesAPIColumn(ctx, db); err != nil {
return err
}
if err := migrationAddAllowedOriginsJSONColumn(ctx, db); err != nil {
return err
}
if err := migrationAddAllowDirectKeysColumn(ctx, db); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -307,3 +313,47 @@ func migrationAddOpenAIUseResponsesAPIColumn(ctx context.Context, db *gorm.DB) e
}
return nil
}

func migrationAddAllowedOriginsJSONColumn(ctx context.Context, db *gorm.DB) error {
m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{
ID: "add_allowed_origins_json_column",
Migrate: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()

if !migrator.HasColumn(&TableClientConfig{}, "allowed_origins_json") {
if err := migrator.AddColumn(&TableClientConfig{}, "allowed_origins_json"); err != nil {
return err
}
}
return nil
},
}})
err := m.Migrate()
if err != nil {
return fmt.Errorf("error while running db migration: %s", err.Error())
}
return nil
}

func migrationAddAllowDirectKeysColumn(ctx context.Context, db *gorm.DB) error {
m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{
ID: "add_allow_direct_keys_column",
Migrate: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()

if !migrator.HasColumn(&TableClientConfig{}, "allow_direct_keys") {
if err := migrator.AddColumn(&TableClientConfig{}, "allow_direct_keys"); err != nil {
return err
}
}
return nil
},
}})
err := m.Migrate()
if err != nil {
return fmt.Errorf("error while running db migration: %s", err.Error())
}
return nil
}
2 changes: 2 additions & 0 deletions framework/configstore/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]TableVirtualKey,
Preload("Customer").
Preload("Budget").
Preload("RateLimit").
Preload("ProviderConfigs").
Preload("Keys", func(db *gorm.DB) *gorm.DB {
return db.Select("id, key_id, models_json")
}).Find(&virtualKeys).Error; err != nil {
Expand All @@ -834,6 +835,7 @@ func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*TableVi
Preload("Customer").
Preload("Budget").
Preload("RateLimit").
Preload("ProviderConfigs").
Preload("Keys", func(db *gorm.DB) *gorm.DB {
return db.Select("id, key_id, models_json")
}).First(&virtualKey, "id = ?", id).Error; err != nil {
Expand Down
1 change: 1 addition & 0 deletions framework/logstore/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type SearchStats struct {
// This is the GORM model with appropriate tags
type Log struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
ParentRequestID *string `gorm:"type:varchar(255)" json:"parent_request_id"`
Timestamp time.Time `gorm:"index;not null" json:"timestamp"`
Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding
Provider string `gorm:"type:varchar(255);index;not null" json:"provider"`
Expand Down
28 changes: 22 additions & 6 deletions plugins/logging/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
)

const (
PluginName = "bifrost-http-logging"
PluginName = "logging"
)

// ContextKey is a custom type for context keys to prevent collisions
Expand All @@ -37,8 +37,8 @@ const (

// Context keys for logging optimization
const (
DroppedCreateContextKey ContextKey = "bifrost-logging-dropped"
CreatedTimestampKey ContextKey = "bifrost-logging-created-timestamp"
DroppedCreateContextKey ContextKey = "logging-dropped"
CreatedTimestampKey ContextKey = "logging-created-timestamp"
)

// UpdateLogData contains data for log entry updates
Expand All @@ -59,7 +59,8 @@ type UpdateLogData struct {
// LogMessage represents a message in the logging queue
type LogMessage struct {
Operation LogOperation
RequestID string
RequestID string // Unique ID for the request
ParentRequestID string // Unique ID for the parent request
Timestamp time.Time // Of the preHook/postHook call
InitialData *InitialLogData // For create operations
SemanticCacheDebug *schemas.BifrostCacheDebug // For semantic cache operations
Expand Down Expand Up @@ -224,6 +225,7 @@ func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest
p.logger.Error("request-id not found in context or is empty")
return req, nil, nil
}

createdTimestamp := time.Now()
// If request type is streaming we create a stream accumulator
if bifrost.IsStreamRequestType(req.RequestType) {
Expand Down Expand Up @@ -271,13 +273,22 @@ func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest
// Queue the log creation message (non-blocking) - Using sync.Pool
logMsg := p.getLogMessage()
logMsg.Operation = LogOperationCreate
logMsg.RequestID = requestID

// If fallback request ID is present, use it instead of the primary request ID
fallbackRequestID, ok := (*ctx).Value(schemas.BifrostContextKeyFallbackRequestID).(string)
if ok && fallbackRequestID != "" {
logMsg.RequestID = fallbackRequestID
logMsg.ParentRequestID = requestID
} else {
logMsg.RequestID = requestID
}

logMsg.Timestamp = createdTimestamp
logMsg.InitialData = initialData

go func(logMsg *LogMessage) {
defer p.putLogMessage(logMsg) // Return to pool when done
if err := p.insertInitialLogEntry(p.ctx, logMsg.RequestID, logMsg.Timestamp, logMsg.InitialData); err != nil {
if err := p.insertInitialLogEntry(p.ctx, logMsg.RequestID, logMsg.ParentRequestID, logMsg.Timestamp, logMsg.InitialData); err != nil {
p.logger.Error("failed to insert initial log entry for request %s: %v", logMsg.RequestID, err)
} else {
// Call callback for initial log creation (WebSocket "create" message)
Expand Down Expand Up @@ -324,6 +335,11 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes
p.logger.Error("request-id not found in context or is empty")
return result, bifrostErr, nil
}
// If fallback request ID is present, use it instead of the primary request ID
fallbackRequestID, ok := (*ctx).Value(schemas.BifrostContextKeyFallbackRequestID).(string)
if ok && fallbackRequestID != "" {
requestID = fallbackRequestID
}
requestType, _, _ := bifrost.GetRequestFields(result, bifrostErr)
// Queue the log update message (non-blocking) - use same pattern for both streaming and regular
logMsg := p.getLogMessage()
Expand Down
6 changes: 5 additions & 1 deletion plugins/logging/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

// insertInitialLogEntry creates a new log entry in the database using GORM
func (p *LoggerPlugin) insertInitialLogEntry(ctx context.Context, requestID string, timestamp time.Time, data *InitialLogData) error {
func (p *LoggerPlugin) insertInitialLogEntry(ctx context.Context, requestID string, parentRequestID string, timestamp time.Time, data *InitialLogData) error {
entry := &logstore.Log{
ID: requestID,
Timestamp: timestamp,
Expand All @@ -30,6 +30,10 @@ func (p *LoggerPlugin) insertInitialLogEntry(ctx context.Context, requestID stri
TranscriptionInputParsed: data.TranscriptionInput,
}

if parentRequestID != "" {
entry.ParentRequestID = &parentRequestID
}

return p.store.Create(ctx, entry)
}

Expand Down
Loading