Skip to content

Commit 9c3a9b2

Browse files
Pratham-Mishra04TejasGhatte
authored andcommitted
feat: bedrock responses streaming added
1 parent 1fe62ac commit 9c3a9b2

File tree

2 files changed

+319
-58
lines changed

2 files changed

+319
-58
lines changed

core/providers/bedrock.go

Lines changed: 219 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
8484
return getProviderName(schemas.Bedrock, provider.customProviderConfig)
8585
}
8686

87-
// CompleteRequest sends a request to Bedrock's API and handles the response.
87+
// completeRequest sends a request to Bedrock's API and handles the response.
8888
// It constructs the API URL, sets up AWS authentication, and processes the response.
8989
// Returns the response body, request latency, or an error if the request fails.
9090
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody interface{}, path string, key schemas.Key) ([]byte, time.Duration, *schemas.BifrostError) {
@@ -206,6 +206,70 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
206206
return body, latency, nil
207207
}
208208

209+
// makeStreamingRequest creates a streaming request to Bedrock's API.
210+
// It formats the request, sends it to Bedrock, and returns the response.
211+
// Returns the response body and an error if the request fails.
212+
func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, requestBody interface{}, key schemas.Key, model string) (*http.Response, *schemas.BifrostError) {
213+
providerName := provider.GetProviderKey()
214+
215+
if key.BedrockKeyConfig == nil {
216+
return nil, newConfigurationError("bedrock key config is not provided", providerName)
217+
}
218+
219+
// Format the path with proper model identifier for streaming
220+
path := provider.getModelPath("converse-stream", model, key)
221+
222+
region := "us-east-1"
223+
if key.BedrockKeyConfig.Region != nil {
224+
region = *key.BedrockKeyConfig.Region
225+
}
226+
227+
// Create the streaming request
228+
jsonBody, jsonErr := sonic.Marshal(requestBody)
229+
if jsonErr != nil {
230+
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, providerName)
231+
}
232+
233+
// Create HTTP request for streaming
234+
req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody))
235+
if reqErr != nil {
236+
return nil, newBifrostOperationError("error creating request", reqErr, providerName)
237+
}
238+
239+
// Set any extra headers from network config
240+
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
241+
242+
// If Value is set, use API Key authentication - else use IAM role authentication
243+
if key.Value != "" {
244+
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
245+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
246+
} else {
247+
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
248+
// Sign the request using either explicit credentials or IAM role authentication
249+
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil {
250+
return nil, err
251+
}
252+
}
253+
254+
// Make the request
255+
resp, respErr := provider.client.Do(req)
256+
if respErr != nil {
257+
if errors.Is(respErr, fasthttp.ErrTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
258+
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, respErr, provider.GetProviderKey())
259+
}
260+
return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName)
261+
}
262+
263+
// Check for HTTP errors
264+
if resp.StatusCode != http.StatusOK {
265+
body, _ := io.ReadAll(resp.Body)
266+
resp.Body.Close()
267+
return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil)
268+
}
269+
270+
return resp, nil
271+
}
272+
209273
// signAWSRequest signs an HTTP request using AWS Signature Version 4.
210274
// It is used in providers like Bedrock.
211275
// It sets required headers, calculates the request body hash, and signs the request
@@ -424,64 +488,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
424488

425489
providerName := provider.GetProviderKey()
426490

427-
if key.BedrockKeyConfig == nil {
428-
return nil, newConfigurationError("bedrock key config is not provided", providerName)
429-
}
430-
431491
reqBody, err := bedrock.ToBedrockChatCompletionRequest(request)
432492
if err != nil {
433493
return nil, newBifrostOperationError("failed to convert request", err, providerName)
434494
}
435495

436-
// Format the path with proper model identifier for streaming
437-
path := provider.getModelPath("converse-stream", request.Model, key)
438-
439-
region := "us-east-1"
440-
if key.BedrockKeyConfig.Region != nil {
441-
region = *key.BedrockKeyConfig.Region
442-
}
443-
444-
// Create the streaming request
445-
jsonBody, jsonErr := sonic.Marshal(reqBody)
446-
if jsonErr != nil {
447-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, providerName)
448-
}
449-
450-
// Create HTTP request for streaming
451-
req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody))
452-
if reqErr != nil {
453-
return nil, newBifrostOperationError("error creating request", reqErr, providerName)
454-
}
455-
456-
// Set any extra headers from network config
457-
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
458-
459-
// If Value is set, use API Key authentication - else use IAM role authentication
460-
if key.Value != "" {
461-
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
462-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
463-
} else {
464-
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
465-
// Sign the request using either explicit credentials or IAM role authentication
466-
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil {
467-
return nil, err
468-
}
469-
}
470-
471-
// Make the request
472-
resp, respErr := provider.client.Do(req)
473-
if respErr != nil {
474-
if errors.Is(respErr, fasthttp.ErrTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
475-
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, respErr, provider.GetProviderKey())
476-
}
477-
return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName)
478-
}
479-
480-
// Check for HTTP errors
481-
if resp.StatusCode != http.StatusOK {
482-
body, _ := io.ReadAll(resp.Body)
483-
resp.Body.Close()
484-
return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil)
496+
resp, bifrostErr := provider.makeStreamingRequest(ctx, reqBody, key, request.Model)
497+
if bifrostErr != nil {
498+
return nil, bifrostErr
485499
}
486500

487501
// Create response channel
@@ -655,6 +669,157 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
655669
return bifrostResponse, nil
656670
}
657671

672+
// ChatCompletionStream performs a streaming chat completion request to Bedrock's API.
673+
// It formats the request, sends it to Bedrock, and processes the streaming response.
674+
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
675+
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
676+
if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
677+
return nil, err
678+
}
679+
680+
providerName := provider.GetProviderKey()
681+
682+
reqBody, err := bedrock.ToBedrockResponsesRequest(request)
683+
if err != nil {
684+
return nil, newBifrostOperationError("failed to convert request", err, providerName)
685+
}
686+
687+
resp, bifrostErr := provider.makeStreamingRequest(ctx, reqBody, key, request.Model)
688+
if bifrostErr != nil {
689+
return nil, bifrostErr
690+
}
691+
692+
// Create response channel
693+
responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize)
694+
695+
// Start streaming in a goroutine
696+
go func() {
697+
defer close(responseChan)
698+
defer resp.Body.Close()
699+
700+
// Process AWS Event Stream format
701+
var usage *schemas.LLMUsage
702+
chunkIndex := 0
703+
704+
// Process AWS Event Stream format using proper decoder
705+
startTime := time.Now()
706+
lastChunkTime := startTime
707+
decoder := eventstream.NewDecoder()
708+
payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer
709+
710+
for {
711+
// Decode a single EventStream message
712+
message, err := decoder.Decode(resp.Body, payloadBuf)
713+
if err != nil {
714+
if err == io.EOF {
715+
// End of stream - this is normal
716+
break
717+
}
718+
provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err))
719+
processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
720+
return
721+
}
722+
723+
// Process the decoded message payload (contains JSON for normal events)
724+
if len(message.Payload) > 0 {
725+
if msgTypeHeader := message.Headers.Get(":message-type"); msgTypeHeader != nil {
726+
if msgType := msgTypeHeader.String(); msgType != "event" {
727+
excType := msgType
728+
if excHeader := message.Headers.Get(":exception-type"); excHeader != nil {
729+
if v := excHeader.String(); v != "" {
730+
excType = v
731+
}
732+
}
733+
errMsg := string(message.Payload)
734+
err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
735+
processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
736+
return
737+
}
738+
}
739+
740+
// Parse the JSON event into our typed structure
741+
var streamEvent bedrock.BedrockStreamEvent
742+
if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil {
743+
provider.logger.Debug(fmt.Sprintf("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)))
744+
return
745+
}
746+
747+
if chunkIndex == 0 {
748+
sendCreatedEventResponsesChunk(ctx, postHookRunner, provider.GetProviderKey(), request.Model, startTime, responseChan, provider.logger)
749+
sendInProgressEventResponsesChunk(ctx, postHookRunner, provider.GetProviderKey(), request.Model, startTime, responseChan, provider.logger)
750+
chunkIndex = 2
751+
}
752+
753+
if streamEvent.Usage != nil {
754+
usage = &schemas.LLMUsage{
755+
ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{
756+
InputTokens: streamEvent.Usage.InputTokens,
757+
OutputTokens: streamEvent.Usage.OutputTokens,
758+
},
759+
TotalTokens: streamEvent.Usage.TotalTokens,
760+
}
761+
}
762+
763+
response, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex)
764+
if response != nil {
765+
766+
response.ExtraFields = schemas.BifrostResponseExtraFields{
767+
RequestType: schemas.ResponsesStreamRequest,
768+
Provider: providerName,
769+
ModelRequested: request.Model,
770+
ChunkIndex: chunkIndex,
771+
Latency: time.Since(lastChunkTime).Milliseconds(),
772+
}
773+
chunkIndex++
774+
lastChunkTime = time.Now()
775+
776+
if provider.sendBackRawResponse {
777+
response.ExtraFields.RawResponse = string(message.Payload)
778+
}
779+
780+
processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger)
781+
}
782+
if bifrostErr != nil {
783+
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
784+
RequestType: schemas.ResponsesStreamRequest,
785+
Provider: providerName,
786+
ModelRequested: request.Model,
787+
}
788+
processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
789+
return
790+
}
791+
}
792+
}
793+
794+
// Send final response
795+
response := &schemas.BifrostResponse{
796+
ResponsesStreamResponse: &schemas.ResponsesStreamResponse{
797+
Type: schemas.ResponsesStreamResponseTypeCompleted,
798+
SequenceNumber: chunkIndex + 1,
799+
Response: &schemas.ResponsesStreamResponseStruct{
800+
Usage: &schemas.ResponsesResponseUsage{
801+
ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{
802+
InputTokens: usage.InputTokens,
803+
OutputTokens: usage.OutputTokens,
804+
},
805+
TotalTokens: usage.TotalTokens,
806+
},
807+
},
808+
},
809+
ExtraFields: schemas.BifrostResponseExtraFields{
810+
RequestType: schemas.ResponsesStreamRequest,
811+
Provider: providerName,
812+
ModelRequested: request.Model,
813+
ChunkIndex: chunkIndex + 1,
814+
Latency: time.Since(startTime).Milliseconds(),
815+
},
816+
}
817+
handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger)
818+
}()
819+
820+
return responseChan, nil
821+
}
822+
658823
// Embedding generates embeddings for the given input text(s) using Amazon Bedrock.
659824
// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred.
660825
func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
@@ -770,7 +935,3 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key
770935

771936
return path
772937
}
773-
774-
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
775-
return nil, newUnsupportedOperationError("responses stream", "bedrock")
776-
}

0 commit comments

Comments
 (0)