Skip to content

Commit f124b61

Browse files
Pratham-Mishra04 authored and TejasGhatte committed
feat: bedrock responses streaming added
1 parent 4b72401 commit f124b61

File tree

2 files changed

+319
-58
lines changed

2 files changed

+319
-58
lines changed

core/providers/bedrock.go

Lines changed: 219 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
8383
return getProviderName(schemas.Bedrock, provider.customProviderConfig)
8484
}
8585

86-
// CompleteRequest sends a request to Bedrock's API and handles the response.
86+
// completeRequest sends a request to Bedrock's API and handles the response.
8787
// It constructs the API URL, sets up AWS authentication, and processes the response.
8888
// Returns the response body, request latency, or an error if the request fails.
8989
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody interface{}, path string, key schemas.Key) ([]byte, time.Duration, *schemas.BifrostError) {
@@ -205,6 +205,70 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
205205
return body, latency, nil
206206
}
207207

208+
// makeStreamingRequest creates a streaming request to Bedrock's API.
209+
// It formats the request, sends it to Bedrock, and returns the response.
210+
// Returns the response body and an error if the request fails.
211+
func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, requestBody interface{}, key schemas.Key, model string) (*http.Response, *schemas.BifrostError) {
212+
providerName := provider.GetProviderKey()
213+
214+
if key.BedrockKeyConfig == nil {
215+
return nil, newConfigurationError("bedrock key config is not provided", providerName)
216+
}
217+
218+
// Format the path with proper model identifier for streaming
219+
path := provider.getModelPath("converse-stream", model, key)
220+
221+
region := "us-east-1"
222+
if key.BedrockKeyConfig.Region != nil {
223+
region = *key.BedrockKeyConfig.Region
224+
}
225+
226+
// Create the streaming request
227+
jsonBody, jsonErr := sonic.Marshal(requestBody)
228+
if jsonErr != nil {
229+
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, providerName)
230+
}
231+
232+
// Create HTTP request for streaming
233+
req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody))
234+
if reqErr != nil {
235+
return nil, newBifrostOperationError("error creating request", reqErr, providerName)
236+
}
237+
238+
// Set any extra headers from network config
239+
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
240+
241+
// If Value is set, use API Key authentication - else use IAM role authentication
242+
if key.Value != "" {
243+
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
244+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
245+
} else {
246+
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
247+
// Sign the request using either explicit credentials or IAM role authentication
248+
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil {
249+
return nil, err
250+
}
251+
}
252+
253+
// Make the request
254+
resp, respErr := provider.client.Do(req)
255+
if respErr != nil {
256+
if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
257+
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, respErr, provider.GetProviderKey())
258+
}
259+
return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName)
260+
}
261+
262+
// Check for HTTP errors
263+
if resp.StatusCode != http.StatusOK {
264+
body, _ := io.ReadAll(resp.Body)
265+
resp.Body.Close()
266+
return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil)
267+
}
268+
269+
return resp, nil
270+
}
271+
208272
// signAWSRequest signs an HTTP request using AWS Signature Version 4.
209273
// It is used in providers like Bedrock.
210274
// It sets required headers, calculates the request body hash, and signs the request
@@ -423,64 +487,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
423487

424488
providerName := provider.GetProviderKey()
425489

426-
if key.BedrockKeyConfig == nil {
427-
return nil, newConfigurationError("bedrock key config is not provided", providerName)
428-
}
429-
430490
reqBody, err := bedrock.ToBedrockChatCompletionRequest(request)
431491
if err != nil {
432492
return nil, newBifrostOperationError("failed to convert request", err, providerName)
433493
}
434494

435-
// Format the path with proper model identifier for streaming
436-
path := provider.getModelPath("converse-stream", request.Model, key)
437-
438-
region := "us-east-1"
439-
if key.BedrockKeyConfig.Region != nil {
440-
region = *key.BedrockKeyConfig.Region
441-
}
442-
443-
// Create the streaming request
444-
jsonBody, jsonErr := sonic.Marshal(reqBody)
445-
if jsonErr != nil {
446-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, providerName)
447-
}
448-
449-
// Create HTTP request for streaming
450-
req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody))
451-
if reqErr != nil {
452-
return nil, newBifrostOperationError("error creating request", reqErr, providerName)
453-
}
454-
455-
// Set any extra headers from network config
456-
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
457-
458-
// If Value is set, use API Key authentication - else use IAM role authentication
459-
if key.Value != "" {
460-
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
461-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
462-
} else {
463-
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
464-
// Sign the request using either explicit credentials or IAM role authentication
465-
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil {
466-
return nil, err
467-
}
468-
}
469-
470-
// Make the request
471-
resp, respErr := provider.client.Do(req)
472-
if respErr != nil {
473-
if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.Canceled) || errors.Is(respErr, context.DeadlineExceeded) {
474-
return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, respErr, provider.GetProviderKey())
475-
}
476-
return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, providerName)
477-
}
478-
479-
// Check for HTTP errors
480-
if resp.StatusCode != http.StatusOK {
481-
body, _ := io.ReadAll(resp.Body)
482-
resp.Body.Close()
483-
return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerName, nil, nil)
495+
resp, bifrostErr := provider.makeStreamingRequest(ctx, reqBody, key, request.Model)
496+
if bifrostErr != nil {
497+
return nil, bifrostErr
484498
}
485499

486500
// Create response channel
@@ -654,6 +668,157 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
654668
return bifrostResponse, nil
655669
}
656670

671+
// ChatCompletionStream performs a streaming chat completion request to Bedrock's API.
672+
// It formats the request, sends it to Bedrock, and processes the streaming response.
673+
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
674+
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
675+
if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
676+
return nil, err
677+
}
678+
679+
providerName := provider.GetProviderKey()
680+
681+
reqBody, err := bedrock.ToBedrockResponsesRequest(request)
682+
if err != nil {
683+
return nil, newBifrostOperationError("failed to convert request", err, providerName)
684+
}
685+
686+
resp, bifrostErr := provider.makeStreamingRequest(ctx, reqBody, key, request.Model)
687+
if bifrostErr != nil {
688+
return nil, bifrostErr
689+
}
690+
691+
// Create response channel
692+
responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize)
693+
694+
// Start streaming in a goroutine
695+
go func() {
696+
defer close(responseChan)
697+
defer resp.Body.Close()
698+
699+
// Process AWS Event Stream format
700+
var usage *schemas.LLMUsage
701+
chunkIndex := 0
702+
703+
// Process AWS Event Stream format using proper decoder
704+
startTime := time.Now()
705+
lastChunkTime := startTime
706+
decoder := eventstream.NewDecoder()
707+
payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer
708+
709+
for {
710+
// Decode a single EventStream message
711+
message, err := decoder.Decode(resp.Body, payloadBuf)
712+
if err != nil {
713+
if err == io.EOF {
714+
// End of stream - this is normal
715+
break
716+
}
717+
provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err))
718+
processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
719+
return
720+
}
721+
722+
// Process the decoded message payload (contains JSON for normal events)
723+
if len(message.Payload) > 0 {
724+
if msgTypeHeader := message.Headers.Get(":message-type"); msgTypeHeader != nil {
725+
if msgType := msgTypeHeader.String(); msgType != "event" {
726+
excType := msgType
727+
if excHeader := message.Headers.Get(":exception-type"); excHeader != nil {
728+
if v := excHeader.String(); v != "" {
729+
excType = v
730+
}
731+
}
732+
errMsg := string(message.Payload)
733+
err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
734+
processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
735+
return
736+
}
737+
}
738+
739+
// Parse the JSON event into our typed structure
740+
var streamEvent bedrock.BedrockStreamEvent
741+
if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil {
742+
provider.logger.Debug(fmt.Sprintf("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)))
743+
return
744+
}
745+
746+
if chunkIndex == 0 {
747+
sendCreatedEventResponsesChunk(ctx, postHookRunner, provider.GetProviderKey(), request.Model, startTime, responseChan, provider.logger)
748+
sendInProgressEventResponsesChunk(ctx, postHookRunner, provider.GetProviderKey(), request.Model, startTime, responseChan, provider.logger)
749+
chunkIndex = 2
750+
}
751+
752+
if streamEvent.Usage != nil {
753+
usage = &schemas.LLMUsage{
754+
ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{
755+
InputTokens: streamEvent.Usage.InputTokens,
756+
OutputTokens: streamEvent.Usage.OutputTokens,
757+
},
758+
TotalTokens: streamEvent.Usage.TotalTokens,
759+
}
760+
}
761+
762+
response, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex)
763+
if response != nil {
764+
765+
response.ExtraFields = schemas.BifrostResponseExtraFields{
766+
RequestType: schemas.ResponsesStreamRequest,
767+
Provider: providerName,
768+
ModelRequested: request.Model,
769+
ChunkIndex: chunkIndex,
770+
Latency: time.Since(lastChunkTime).Milliseconds(),
771+
}
772+
chunkIndex++
773+
lastChunkTime = time.Now()
774+
775+
if provider.sendBackRawResponse {
776+
response.ExtraFields.RawResponse = string(message.Payload)
777+
}
778+
779+
processAndSendResponse(ctx, postHookRunner, response, responseChan, provider.logger)
780+
}
781+
if bifrostErr != nil {
782+
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
783+
RequestType: schemas.ResponsesStreamRequest,
784+
Provider: providerName,
785+
ModelRequested: request.Model,
786+
}
787+
processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
788+
return
789+
}
790+
}
791+
}
792+
793+
// Send final response
794+
response := &schemas.BifrostResponse{
795+
ResponsesStreamResponse: &schemas.ResponsesStreamResponse{
796+
Type: schemas.ResponsesStreamResponseTypeCompleted,
797+
SequenceNumber: chunkIndex + 1,
798+
Response: &schemas.ResponsesStreamResponseStruct{
799+
Usage: &schemas.ResponsesResponseUsage{
800+
ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{
801+
InputTokens: usage.InputTokens,
802+
OutputTokens: usage.OutputTokens,
803+
},
804+
TotalTokens: usage.TotalTokens,
805+
},
806+
},
807+
},
808+
ExtraFields: schemas.BifrostResponseExtraFields{
809+
RequestType: schemas.ResponsesStreamRequest,
810+
Provider: providerName,
811+
ModelRequested: request.Model,
812+
ChunkIndex: chunkIndex + 1,
813+
Latency: time.Since(startTime).Milliseconds(),
814+
},
815+
}
816+
handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger)
817+
}()
818+
819+
return responseChan, nil
820+
}
821+
657822
// Embedding generates embeddings for the given input text(s) using Amazon Bedrock.
658823
// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred.
659824
func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
@@ -769,7 +934,3 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key
769934

770935
return path
771936
}
772-
773-
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
774-
return nil, newUnsupportedOperationError("responses stream", "bedrock")
775-
}

0 commit comments

Comments
 (0)