Skip to content

Commit 39465aa

Browse files
feat: responses streaming added
1 parent 48c06ab commit 39465aa

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

51 files changed

2,416 additions and 790 deletions (lines changed)

core/providers/anthropic.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -395,7 +395,7 @@ func handleAnthropicStreaming(
395395
},
396396
}
397397
}
398-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
398+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
399399
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerType)
400400
}
401401
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType)
@@ -422,7 +422,7 @@ func handleAnthropicStreaming(
422422
},
423423
}
424424
}
425-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
425+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
426426
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerType)
427427
}
428428
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType)

core/providers/azure.go

Lines changed: 62 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -273,58 +273,6 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.K
273273
return response, nil
274274
}
275275

276-
// Responses performs a responses request to Azure's API.
277-
// It formats the request, sends it to Azure, and processes the response.
278-
// Returns a BifrostResponse containing the completion results or an error if the request fails.
279-
func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
280-
response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
281-
if err != nil {
282-
return nil, err
283-
}
284-
285-
response.ToResponsesOnly()
286-
response.ExtraFields.RequestType = schemas.ResponsesRequest
287-
response.ExtraFields.Provider = provider.GetProviderKey()
288-
response.ExtraFields.ModelRequested = request.Model
289-
290-
return response, nil
291-
}
292-
293-
// Embedding generates embeddings for the given input text(s) using Azure OpenAI.
294-
// The input can be either a single string or a slice of strings for batch embedding.
295-
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
296-
func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
297-
// Use centralized converter
298-
reqBody := openai.ToOpenAIEmbeddingRequest(request)
299-
if reqBody == nil {
300-
return nil, newBifrostOperationError("embedding input is not provided", nil, schemas.Azure)
301-
}
302-
303-
responseBody, latency, err := provider.completeRequest(ctx, reqBody, "embeddings", key, request.Model)
304-
if err != nil {
305-
return nil, err
306-
}
307-
308-
response := &schemas.BifrostResponse{}
309-
310-
// Use enhanced response handler with pre-allocated response
311-
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
312-
if bifrostErr != nil {
313-
return nil, bifrostErr
314-
}
315-
316-
response.ExtraFields.Provider = schemas.Azure
317-
response.ExtraFields.Latency = latency.Milliseconds()
318-
response.ExtraFields.ModelRequested = request.Model
319-
response.ExtraFields.RequestType = schemas.EmbeddingRequest
320-
321-
if provider.sendBackRawResponse {
322-
response.ExtraFields.RawResponse = rawResponse
323-
}
324-
325-
return response, nil
326-
}
327-
328276
// ChatCompletionStream performs a streaming chat completion request to Azure's OpenAI API.
329277
// It supports real-time streaming of responses using Server-Sent Events (SSE).
330278
// Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication.
@@ -369,7 +317,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
369317
}
370318

371319
// Use shared streaming logic from OpenAI
372-
return handleOpenAIStreaming(
320+
return handleOpenAIChatCompletionStreaming(
373321
ctx,
374322
provider.streamClient,
375323
fullURL,
@@ -383,6 +331,67 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
383331
)
384332
}
385333

334+
// Responses performs a responses request to Azure's API.
335+
// It formats the request, sends it to Azure, and processes the response.
336+
// Returns a BifrostResponse containing the completion results or an error if the request fails.
337+
func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
338+
response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
339+
if err != nil {
340+
return nil, err
341+
}
342+
343+
response.ToResponsesOnly()
344+
response.ExtraFields.RequestType = schemas.ResponsesRequest
345+
response.ExtraFields.Provider = provider.GetProviderKey()
346+
response.ExtraFields.ModelRequested = request.Model
347+
348+
return response, nil
349+
}
350+
351+
func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
352+
return provider.ChatCompletionStream(
353+
ctx,
354+
getResponsesChunkConverterCombinedPostHookRunner(postHookRunner),
355+
key,
356+
request.ToChatRequest(),
357+
)
358+
}
359+
360+
// Embedding generates embeddings for the given input text(s) using Azure OpenAI.
361+
// The input can be either a single string or a slice of strings for batch embedding.
362+
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
363+
func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
364+
// Use centralized converter
365+
reqBody := openai.ToOpenAIEmbeddingRequest(request)
366+
if reqBody == nil {
367+
return nil, newBifrostOperationError("embedding input is not provided", nil, schemas.Azure)
368+
}
369+
370+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, "embeddings", key, request.Model)
371+
if err != nil {
372+
return nil, err
373+
}
374+
375+
response := &schemas.BifrostResponse{}
376+
377+
// Use enhanced response handler with pre-allocated response
378+
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
379+
if bifrostErr != nil {
380+
return nil, bifrostErr
381+
}
382+
383+
response.ExtraFields.Provider = schemas.Azure
384+
response.ExtraFields.Latency = latency.Milliseconds()
385+
response.ExtraFields.ModelRequested = request.Model
386+
response.ExtraFields.RequestType = schemas.EmbeddingRequest
387+
388+
if provider.sendBackRawResponse {
389+
response.ExtraFields.RawResponse = rawResponse
390+
}
391+
392+
return response, nil
393+
}
394+
386395
func (provider *AzureProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
387396
return nil, newUnsupportedOperationError("speech", "azure")
388397
}
@@ -398,7 +407,3 @@ func (provider *AzureProvider) Transcription(ctx context.Context, key schemas.Ke
398407
func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
399408
return nil, newUnsupportedOperationError("transcription stream", "azure")
400409
}
401-
402-
func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
403-
return nil, newUnsupportedOperationError("responses stream", "azure")
404-
}

core/providers/bedrock.go

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ import (
2424
schemas "github.com/maximhq/bifrost/core/schemas"
2525
"github.com/maximhq/bifrost/core/schemas/providers/bedrock"
2626
cohere "github.com/maximhq/bifrost/core/schemas/providers/cohere"
27-
"github.com/valyala/fasthttp"
2827
)
2928

3029
// BedrockProvider implements the Provider interface for AWS Bedrock.
@@ -156,7 +155,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
156155
},
157156
}
158157
}
159-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
158+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
160159
return nil, latency, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
161160
}
162161
return nil, latency, &schemas.BifrostError{
@@ -471,7 +470,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
471470
// Make the request
472471
resp, respErr := provider.client.Do(req)
473472
if respErr != nil {
474-
if errors.Is(respErr, fasthttp.ErrTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
473+
if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
475474
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, respErr, provider.GetProviderKey())
476475
}
477476
return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName)

core/providers/cerebras.go

Lines changed: 29 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -111,32 +111,13 @@ func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schema
111111
)
112112
}
113113

114-
func (provider *CerebrasProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
115-
response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
116-
if err != nil {
117-
return nil, err
118-
}
119-
120-
response.ToResponsesOnly()
121-
response.ExtraFields.RequestType = schemas.ResponsesRequest
122-
response.ExtraFields.Provider = provider.GetProviderKey()
123-
response.ExtraFields.ModelRequested = request.Model
124-
125-
return response, nil
126-
}
127-
128-
// Embedding is not supported by the Cerebras provider.
129-
func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
130-
return nil, newUnsupportedOperationError("embedding", "cerebras")
131-
}
132-
133114
// ChatCompletionStream performs a streaming chat completion request to the Cerebras API.
134115
// It supports real-time streaming of responses using Server-Sent Events (SSE).
135116
// Uses Cerebras's OpenAI-compatible streaming format.
136117
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
137118
func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
138119
// Use shared OpenAI-compatible streaming logic
139-
return handleOpenAIStreaming(
120+
return handleOpenAIChatCompletionStreaming(
140121
ctx,
141122
provider.streamClient,
142123
provider.networkConfig.BaseURL+"/v1/chat/completions",
@@ -150,6 +131,34 @@ func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, post
150131
)
151132
}
152133

134+
func (provider *CerebrasProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
135+
response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
136+
if err != nil {
137+
return nil, err
138+
}
139+
140+
response.ToResponsesOnly()
141+
response.ExtraFields.RequestType = schemas.ResponsesRequest
142+
response.ExtraFields.Provider = provider.GetProviderKey()
143+
response.ExtraFields.ModelRequested = request.Model
144+
145+
return response, nil
146+
}
147+
148+
func (provider *CerebrasProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
149+
return provider.ChatCompletionStream(
150+
ctx,
151+
getResponsesChunkConverterCombinedPostHookRunner(postHookRunner),
152+
key,
153+
request.ToChatRequest(),
154+
)
155+
}
156+
157+
// Embedding is not supported by the Cerebras provider.
158+
func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
159+
return nil, newUnsupportedOperationError("embedding", "cerebras")
160+
}
161+
153162
func (provider *CerebrasProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
154163
return nil, newUnsupportedOperationError("speech", "cerebras")
155164
}
@@ -165,7 +174,3 @@ func (provider *CerebrasProvider) Transcription(ctx context.Context, key schemas
165174
func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
166175
return nil, newUnsupportedOperationError("transcription stream", "cerebras")
167176
}
168-
169-
func (provider *CerebrasProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
170-
return nil, newUnsupportedOperationError("responses stream", "cerebras")
171-
}

core/providers/cohere.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
377377
},
378378
}
379379
}
380-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
380+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
381381
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
382382
}
383383
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName)
@@ -405,7 +405,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
405405
},
406406
}
407407
}
408-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
408+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
409409
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
410410
}
411411
return nil, &schemas.BifrostError{

core/providers/gemini.go

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHo
177177
}
178178

179179
// Use shared OpenAI-compatible streaming logic
180-
return handleOpenAIStreaming(
180+
return handleOpenAIChatCompletionStreaming(
181181
ctx,
182182
provider.streamClient,
183183
provider.networkConfig.BaseURL+"/openai/chat/completions",
@@ -310,7 +310,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner
310310
},
311311
}
312312
}
313-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
313+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
314314
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
315315
}
316316
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName)
@@ -338,7 +338,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner
338338
},
339339
}
340340
}
341-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
341+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
342342
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
343343
}
344344
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName)
@@ -594,7 +594,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo
594594
},
595595
}
596596
}
597-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
597+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
598598
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
599599
}
600600
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName)
@@ -622,7 +622,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo
622622
},
623623
}
624624
}
625-
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
625+
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
626626
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
627627
}
628628
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName)

0 commit comments

Comments
 (0)