@@ -84,7 +84,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
8484 return getProviderName (schemas .Bedrock , provider .customProviderConfig )
8585}
8686
87- // CompleteRequest sends a request to Bedrock's API and handles the response.
87+ // completeRequest sends a request to Bedrock's API and handles the response.
8888// It constructs the API URL, sets up AWS authentication, and processes the response.
8989// Returns the response body, request latency, or an error if the request fails.
9090func (provider * BedrockProvider ) completeRequest (ctx context.Context , requestBody interface {}, path string , key schemas.Key ) ([]byte , time.Duration , * schemas.BifrostError ) {
@@ -206,6 +206,70 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
206206 return body , latency , nil
207207}
208208
209+ // makeStreamingRequest creates a streaming request to Bedrock's API.
210+ // It formats the request, sends it to Bedrock, and returns the response.
211+ // Returns the response body and an error if the request fails.
212+ func (provider * BedrockProvider ) makeStreamingRequest (ctx context.Context , requestBody interface {}, key schemas.Key , model string ) (* http.Response , * schemas.BifrostError ) {
213+ providerName := provider .GetProviderKey ()
214+
215+ if key .BedrockKeyConfig == nil {
216+ return nil , newConfigurationError ("bedrock key config is not provided" , providerName )
217+ }
218+
219+ // Format the path with proper model identifier for streaming
220+ path := provider .getModelPath ("converse-stream" , model , key )
221+
222+ region := "us-east-1"
223+ if key .BedrockKeyConfig .Region != nil {
224+ region = * key .BedrockKeyConfig .Region
225+ }
226+
227+ // Create the streaming request
228+ jsonBody , jsonErr := sonic .Marshal (requestBody )
229+ if jsonErr != nil {
230+ return nil , newBifrostOperationError (schemas .ErrProviderJSONMarshaling , jsonErr , providerName )
231+ }
232+
233+ // Create HTTP request for streaming
234+ req , reqErr := http .NewRequestWithContext (ctx , "POST" , fmt .Sprintf ("https://bedrock-runtime.%s.amazonaws.com/model/%s" , region , path ), bytes .NewReader (jsonBody ))
235+ if reqErr != nil {
236+ return nil , newBifrostOperationError ("error creating request" , reqErr , providerName )
237+ }
238+
239+ // Set any extra headers from network config
240+ setExtraHeadersHTTP (req , provider .networkConfig .ExtraHeaders , nil )
241+
242+ // If Value is set, use API Key authentication - else use IAM role authentication
243+ if key .Value != "" {
244+ req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
245+ req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , key .Value ))
246+ } else {
247+ req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
248+ // Sign the request using either explicit credentials or IAM role authentication
249+ if err := signAWSRequest (ctx , req , key .BedrockKeyConfig .AccessKey , key .BedrockKeyConfig .SecretKey , key .BedrockKeyConfig .SessionToken , region , "bedrock" , providerName ); err != nil {
250+ return nil , err
251+ }
252+ }
253+
254+ // Make the request
255+ resp , respErr := provider .client .Do (req )
256+ if respErr != nil {
257+ if errors .Is (respErr , fasthttp .ErrTimeout ) || errors .Is (respErr , context .Canceled ) || errors .Is (respErr , context .DeadlineExceeded ) {
258+ return nil , newBifrostOperationError (schemas .ErrProviderRequestTimedOut , respErr , provider .GetProviderKey ())
259+ }
260+ return nil , newBifrostOperationError (schemas .ErrProviderRequest , respErr , providerName )
261+ }
262+
263+ // Check for HTTP errors
264+ if resp .StatusCode != http .StatusOK {
265+ body , _ := io .ReadAll (resp .Body )
266+ resp .Body .Close ()
267+ return nil , newProviderAPIError (fmt .Sprintf ("HTTP error from %s: %d" , providerName , resp .StatusCode ), fmt .Errorf ("%s" , string (body )), resp .StatusCode , providerName , nil , nil )
268+ }
269+
270+ return resp , nil
271+ }
272+
209273// signAWSRequest signs an HTTP request using AWS Signature Version 4.
210274// It is used in providers like Bedrock.
211275// It sets required headers, calculates the request body hash, and signs the request
@@ -424,64 +488,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
424488
425489 providerName := provider .GetProviderKey ()
426490
427- if key .BedrockKeyConfig == nil {
428- return nil , newConfigurationError ("bedrock key config is not provided" , providerName )
429- }
430-
431491 reqBody , err := bedrock .ToBedrockChatCompletionRequest (request )
432492 if err != nil {
433493 return nil , newBifrostOperationError ("failed to convert request" , err , providerName )
434494 }
435495
436- // Format the path with proper model identifier for streaming
437- path := provider .getModelPath ("converse-stream" , request .Model , key )
438-
439- region := "us-east-1"
440- if key .BedrockKeyConfig .Region != nil {
441- region = * key .BedrockKeyConfig .Region
442- }
443-
444- // Create the streaming request
445- jsonBody , jsonErr := sonic .Marshal (reqBody )
446- if jsonErr != nil {
447- return nil , newBifrostOperationError (schemas .ErrProviderJSONMarshaling , jsonErr , providerName )
448- }
449-
450- // Create HTTP request for streaming
451- req , reqErr := http .NewRequestWithContext (ctx , "POST" , fmt .Sprintf ("https://bedrock-runtime.%s.amazonaws.com/model/%s" , region , path ), bytes .NewReader (jsonBody ))
452- if reqErr != nil {
453- return nil , newBifrostOperationError ("error creating request" , reqErr , providerName )
454- }
455-
456- // Set any extra headers from network config
457- setExtraHeadersHTTP (req , provider .networkConfig .ExtraHeaders , nil )
458-
459- // If Value is set, use API Key authentication - else use IAM role authentication
460- if key .Value != "" {
461- req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
462- req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , key .Value ))
463- } else {
464- req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
465- // Sign the request using either explicit credentials or IAM role authentication
466- if err := signAWSRequest (ctx , req , key .BedrockKeyConfig .AccessKey , key .BedrockKeyConfig .SecretKey , key .BedrockKeyConfig .SessionToken , region , "bedrock" , providerName ); err != nil {
467- return nil , err
468- }
469- }
470-
471- // Make the request
472- resp , respErr := provider .client .Do (req )
473- if respErr != nil {
474- if errors .Is (respErr , fasthttp .ErrTimeout ) || errors .Is (respErr , context .Canceled ) || errors .Is (respErr , context .DeadlineExceeded ) {
475- return nil , newBifrostOperationError (schemas .ErrProviderRequestTimedOut , respErr , provider .GetProviderKey ())
476- }
477- return nil , newBifrostOperationError (schemas .ErrProviderRequest , respErr , providerName )
478- }
479-
480- // Check for HTTP errors
481- if resp .StatusCode != http .StatusOK {
482- body , _ := io .ReadAll (resp .Body )
483- resp .Body .Close ()
484- return nil , newProviderAPIError (fmt .Sprintf ("HTTP error from %s: %d" , providerName , resp .StatusCode ), fmt .Errorf ("%s" , string (body )), resp .StatusCode , providerName , nil , nil )
496+ resp , bifrostErr := provider .makeStreamingRequest (ctx , reqBody , key , request .Model )
497+ if bifrostErr != nil {
498+ return nil , bifrostErr
485499 }
486500
487501 // Create response channel
@@ -655,6 +669,157 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
655669 return bifrostResponse , nil
656670}
657671
672+ // ChatCompletionStream performs a streaming chat completion request to Bedrock's API.
673+ // It formats the request, sends it to Bedrock, and processes the streaming response.
674+ // Returns a channel for streaming BifrostResponse objects or an error if the request fails.
675+ func (provider * BedrockProvider ) ResponsesStream (ctx context.Context , postHookRunner schemas.PostHookRunner , key schemas.Key , request * schemas.BifrostResponsesRequest ) (chan * schemas.BifrostStream , * schemas.BifrostError ) {
676+ if err := checkOperationAllowed (schemas .Bedrock , provider .customProviderConfig , schemas .ChatCompletionStreamRequest ); err != nil {
677+ return nil , err
678+ }
679+
680+ providerName := provider .GetProviderKey ()
681+
682+ reqBody , err := bedrock .ToBedrockResponsesRequest (request )
683+ if err != nil {
684+ return nil , newBifrostOperationError ("failed to convert request" , err , providerName )
685+ }
686+
687+ resp , bifrostErr := provider .makeStreamingRequest (ctx , reqBody , key , request .Model )
688+ if bifrostErr != nil {
689+ return nil , bifrostErr
690+ }
691+
692+ // Create response channel
693+ responseChan := make (chan * schemas.BifrostStream , schemas .DefaultStreamBufferSize )
694+
695+ // Start streaming in a goroutine
696+ go func () {
697+ defer close (responseChan )
698+ defer resp .Body .Close ()
699+
700+ // Process AWS Event Stream format
701+ var usage * schemas.LLMUsage
702+ chunkIndex := 0
703+
704+ // Process AWS Event Stream format using proper decoder
705+ startTime := time .Now ()
706+ lastChunkTime := startTime
707+ decoder := eventstream .NewDecoder ()
708+ payloadBuf := make ([]byte , 0 , 1024 * 1024 ) // 1MB payload buffer
709+
710+ for {
711+ // Decode a single EventStream message
712+ message , err := decoder .Decode (resp .Body , payloadBuf )
713+ if err != nil {
714+ if err == io .EOF {
715+ // End of stream - this is normal
716+ break
717+ }
718+ provider .logger .Warn (fmt .Sprintf ("Error decoding %s EventStream message: %v" , providerName , err ))
719+ processAndSendError (ctx , postHookRunner , err , responseChan , schemas .ChatCompletionStreamRequest , providerName , request .Model , provider .logger )
720+ return
721+ }
722+
723+ // Process the decoded message payload (contains JSON for normal events)
724+ if len (message .Payload ) > 0 {
725+ if msgTypeHeader := message .Headers .Get (":message-type" ); msgTypeHeader != nil {
726+ if msgType := msgTypeHeader .String (); msgType != "event" {
727+ excType := msgType
728+ if excHeader := message .Headers .Get (":exception-type" ); excHeader != nil {
729+ if v := excHeader .String (); v != "" {
730+ excType = v
731+ }
732+ }
733+ errMsg := string (message .Payload )
734+ err := fmt .Errorf ("%s stream %s: %s" , providerName , excType , errMsg )
735+ processAndSendError (ctx , postHookRunner , err , responseChan , schemas .ChatCompletionStreamRequest , providerName , request .Model , provider .logger )
736+ return
737+ }
738+ }
739+
740+ // Parse the JSON event into our typed structure
741+ var streamEvent bedrock.BedrockStreamEvent
742+ if err := sonic .Unmarshal (message .Payload , & streamEvent ); err != nil {
743+ provider .logger .Debug (fmt .Sprintf ("Failed to parse JSON from event buffer: %v, data: %s" , err , string (message .Payload )))
744+ return
745+ }
746+
747+ if chunkIndex == 0 {
748+ sendCreatedEventResponsesChunk (ctx , postHookRunner , provider .GetProviderKey (), request .Model , startTime , responseChan , provider .logger )
749+ sendInProgressEventResponsesChunk (ctx , postHookRunner , provider .GetProviderKey (), request .Model , startTime , responseChan , provider .logger )
750+ chunkIndex = 2
751+ }
752+
753+ if streamEvent .Usage != nil {
754+ usage = & schemas.LLMUsage {
755+ ResponsesExtendedResponseUsage : & schemas.ResponsesExtendedResponseUsage {
756+ InputTokens : streamEvent .Usage .InputTokens ,
757+ OutputTokens : streamEvent .Usage .OutputTokens ,
758+ },
759+ TotalTokens : streamEvent .Usage .TotalTokens ,
760+ }
761+ }
762+
763+ response , bifrostErr , _ := streamEvent .ToBifrostResponsesStream (chunkIndex )
764+ if response != nil {
765+
766+ response .ExtraFields = schemas.BifrostResponseExtraFields {
767+ RequestType : schemas .ResponsesStreamRequest ,
768+ Provider : providerName ,
769+ ModelRequested : request .Model ,
770+ ChunkIndex : chunkIndex ,
771+ Latency : time .Since (lastChunkTime ).Milliseconds (),
772+ }
773+ chunkIndex ++
774+ lastChunkTime = time .Now ()
775+
776+ if provider .sendBackRawResponse {
777+ response .ExtraFields .RawResponse = string (message .Payload )
778+ }
779+
780+ processAndSendResponse (ctx , postHookRunner , response , responseChan , provider .logger )
781+ }
782+ if bifrostErr != nil {
783+ bifrostErr .ExtraFields = schemas.BifrostErrorExtraFields {
784+ RequestType : schemas .ResponsesStreamRequest ,
785+ Provider : providerName ,
786+ ModelRequested : request .Model ,
787+ }
788+ processAndSendBifrostError (ctx , postHookRunner , bifrostErr , responseChan , provider .logger )
789+ return
790+ }
791+ }
792+ }
793+
794+ // Send final response
795+ response := & schemas.BifrostResponse {
796+ ResponsesStreamResponse : & schemas.ResponsesStreamResponse {
797+ Type : schemas .ResponsesStreamResponseTypeCompleted ,
798+ SequenceNumber : chunkIndex + 1 ,
799+ Response : & schemas.ResponsesStreamResponseStruct {
800+ Usage : & schemas.ResponsesResponseUsage {
801+ ResponsesExtendedResponseUsage : & schemas.ResponsesExtendedResponseUsage {
802+ InputTokens : usage .InputTokens ,
803+ OutputTokens : usage .OutputTokens ,
804+ },
805+ TotalTokens : usage .TotalTokens ,
806+ },
807+ },
808+ },
809+ ExtraFields : schemas.BifrostResponseExtraFields {
810+ RequestType : schemas .ResponsesStreamRequest ,
811+ Provider : providerName ,
812+ ModelRequested : request .Model ,
813+ ChunkIndex : chunkIndex + 1 ,
814+ Latency : time .Since (startTime ).Milliseconds (),
815+ },
816+ }
817+ handleStreamEndWithSuccess (ctx , response , postHookRunner , responseChan , provider .logger )
818+ }()
819+
820+ return responseChan , nil
821+ }
822+
658823// Embedding generates embeddings for the given input text(s) using Amazon Bedrock.
659824// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred.
660825func (provider * BedrockProvider ) Embedding (ctx context.Context , key schemas.Key , request * schemas.BifrostEmbeddingRequest ) (* schemas.BifrostResponse , * schemas.BifrostError ) {
@@ -770,7 +935,3 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key
770935
771936 return path
772937}
773-
774- func (provider * BedrockProvider ) ResponsesStream (ctx context.Context , postHookRunner schemas.PostHookRunner , key schemas.Key , request * schemas.BifrostResponsesRequest ) (chan * schemas.BifrostStream , * schemas.BifrostError ) {
775- return nil , newUnsupportedOperationError ("responses stream" , "bedrock" )
776- }
0 commit comments