@@ -83,7 +83,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
8383	return  getProviderName (schemas .Bedrock , provider .customProviderConfig )
8484}
8585
86- // CompleteRequest  sends a request to Bedrock's API and handles the response. 
86+ // completeRequest  sends a request to Bedrock's API and handles the response. 
8787// It constructs the API URL, sets up AWS authentication, and processes the response. 
8888// Returns the response body, request latency, or an error if the request fails. 
8989func  (provider  * BedrockProvider ) completeRequest (ctx  context.Context , requestBody  interface {}, path  string , key  schemas.Key ) ([]byte , time.Duration , * schemas.BifrostError ) {
@@ -205,6 +205,70 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
205205	return  body , latency , nil 
206206}
207207
208+ // makeStreamingRequest creates a streaming request to Bedrock's API. 
209+ // It formats the request, sends it to Bedrock, and returns the response. 
210+ // Returns the response body and an error if the request fails. 
211+ func  (provider  * BedrockProvider ) makeStreamingRequest (ctx  context.Context , requestBody  interface {}, key  schemas.Key , model  string ) (* http.Response , * schemas.BifrostError ) {
212+ 	providerName  :=  provider .GetProviderKey ()
213+ 
214+ 	if  key .BedrockKeyConfig  ==  nil  {
215+ 		return  nil , newConfigurationError ("bedrock key config is not provided" , providerName )
216+ 	}
217+ 
218+ 	// Format the path with proper model identifier for streaming 
219+ 	path  :=  provider .getModelPath ("converse-stream" , model , key )
220+ 
221+ 	region  :=  "us-east-1" 
222+ 	if  key .BedrockKeyConfig .Region  !=  nil  {
223+ 		region  =  * key .BedrockKeyConfig .Region 
224+ 	}
225+ 
226+ 	// Create the streaming request 
227+ 	jsonBody , jsonErr  :=  sonic .Marshal (requestBody )
228+ 	if  jsonErr  !=  nil  {
229+ 		return  nil , newBifrostOperationError (schemas .ErrProviderJSONMarshaling , jsonErr , providerName )
230+ 	}
231+ 
232+ 	// Create HTTP request for streaming 
233+ 	req , reqErr  :=  http .NewRequestWithContext (ctx , "POST" , fmt .Sprintf ("https://bedrock-runtime.%s.amazonaws.com/model/%s" , region , path ), bytes .NewReader (jsonBody ))
234+ 	if  reqErr  !=  nil  {
235+ 		return  nil , newBifrostOperationError ("error creating request" , reqErr , providerName )
236+ 	}
237+ 
238+ 	// Set any extra headers from network config 
239+ 	setExtraHeadersHTTP (req , provider .networkConfig .ExtraHeaders , nil )
240+ 
241+ 	// If Value is set, use API Key authentication - else use IAM role authentication 
242+ 	if  key .Value  !=  ""  {
243+ 		req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
244+ 		req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , key .Value ))
245+ 	} else  {
246+ 		req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
247+ 		// Sign the request using either explicit credentials or IAM role authentication 
248+ 		if  err  :=  signAWSRequest (ctx , req , key .BedrockKeyConfig .AccessKey , key .BedrockKeyConfig .SecretKey , key .BedrockKeyConfig .SessionToken , region , "bedrock" , providerName ); err  !=  nil  {
249+ 			return  nil , err 
250+ 		}
251+ 	}
252+ 
253+ 	// Make the request 
254+ 	resp , respErr  :=  provider .client .Do (req )
255+ 	if  respErr  !=  nil  {
256+ 		if  errors .Is (respErr , http .ErrHandlerTimeout ) ||  errors .Is (respErr , context .Canceled ) ||  errors .Is (respErr , context .DeadlineExceeded ) {
257+ 			return  nil , newBifrostOperationError (schemas .ErrProviderRequestTimedOut , respErr , provider .GetProviderKey ())
258+ 		}
259+ 		return  nil , newBifrostOperationError (schemas .ErrProviderRequest , respErr , providerName )
260+ 	}
261+ 
262+ 	// Check for HTTP errors 
263+ 	if  resp .StatusCode  !=  http .StatusOK  {
264+ 		body , _  :=  io .ReadAll (resp .Body )
265+ 		resp .Body .Close ()
266+ 		return  nil , newProviderAPIError (fmt .Sprintf ("HTTP error from %s: %d" , providerName , resp .StatusCode ), fmt .Errorf ("%s" , string (body )), resp .StatusCode , providerName , nil , nil )
267+ 	}
268+ 
269+ 	return  resp , nil 
270+ }
271+ 
208272// signAWSRequest signs an HTTP request using AWS Signature Version 4. 
209273// It is used in providers like Bedrock. 
210274// It sets required headers, calculates the request body hash, and signs the request 
@@ -423,64 +487,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
423487
424488	providerName  :=  provider .GetProviderKey ()
425489
426- 	if  key .BedrockKeyConfig  ==  nil  {
427- 		return  nil , newConfigurationError ("bedrock key config is not provided" , providerName )
428- 	}
429- 
430490	reqBody , err  :=  bedrock .ToBedrockChatCompletionRequest (request )
431491	if  err  !=  nil  {
432492		return  nil , newBifrostOperationError ("failed to convert request" , err , providerName )
433493	}
434494
435- 	// Format the path with proper model identifier for streaming 
436- 	path  :=  provider .getModelPath ("converse-stream" , request .Model , key )
437- 
438- 	region  :=  "us-east-1" 
439- 	if  key .BedrockKeyConfig .Region  !=  nil  {
440- 		region  =  * key .BedrockKeyConfig .Region 
441- 	}
442- 
443- 	// Create the streaming request 
444- 	jsonBody , jsonErr  :=  sonic .Marshal (reqBody )
445- 	if  jsonErr  !=  nil  {
446- 		return  nil , newBifrostOperationError (schemas .ErrProviderJSONMarshaling , jsonErr , providerName )
447- 	}
448- 
449- 	// Create HTTP request for streaming 
450- 	req , reqErr  :=  http .NewRequestWithContext (ctx , "POST" , fmt .Sprintf ("https://bedrock-runtime.%s.amazonaws.com/model/%s" , region , path ), bytes .NewReader (jsonBody ))
451- 	if  reqErr  !=  nil  {
452- 		return  nil , newBifrostOperationError ("error creating request" , reqErr , providerName )
453- 	}
454- 
455- 	// Set any extra headers from network config 
456- 	setExtraHeadersHTTP (req , provider .networkConfig .ExtraHeaders , nil )
457- 
458- 	// If Value is set, use API Key authentication - else use IAM role authentication 
459- 	if  key .Value  !=  ""  {
460- 		req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
461- 		req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , key .Value ))
462- 	} else  {
463- 		req .Header .Set ("Accept" , "application/vnd.amazon.eventstream" )
464- 		// Sign the request using either explicit credentials or IAM role authentication 
465- 		if  err  :=  signAWSRequest (ctx , req , key .BedrockKeyConfig .AccessKey , key .BedrockKeyConfig .SecretKey , key .BedrockKeyConfig .SessionToken , region , "bedrock" , providerName ); err  !=  nil  {
466- 			return  nil , err 
467- 		}
468- 	}
469- 
470- 	// Make the request 
471- 	resp , respErr  :=  provider .client .Do (req )
472- 	if  respErr  !=  nil  {
473- 		if  errors .Is (respErr , http .ErrHandlerTimeout ) ||  errors .Is (respErr , context .Canceled ) ||  errors .Is (respErr , context .DeadlineExceeded ) {
474- 			return  nil , newBifrostOperationError (schemas .ErrProviderRequestTimedOut , respErr , provider .GetProviderKey ())
475- 		}
476- 		return  nil , newBifrostOperationError (schemas .ErrProviderRequest , respErr , providerName )
477- 	}
478- 
479- 	// Check for HTTP errors 
480- 	if  resp .StatusCode  !=  http .StatusOK  {
481- 		body , _  :=  io .ReadAll (resp .Body )
482- 		resp .Body .Close ()
483- 		return  nil , newProviderAPIError (fmt .Sprintf ("HTTP error from %s: %d" , providerName , resp .StatusCode ), fmt .Errorf ("%s" , string (body )), resp .StatusCode , providerName , nil , nil )
495+ 	resp , bifrostErr  :=  provider .makeStreamingRequest (ctx , reqBody , key , request .Model )
496+ 	if  bifrostErr  !=  nil  {
497+ 		return  nil , bifrostErr 
484498	}
485499
486500	// Create response channel 
@@ -654,6 +668,157 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
654668	return  bifrostResponse , nil 
655669}
656670
671+ // ChatCompletionStream performs a streaming chat completion request to Bedrock's API. 
672+ // It formats the request, sends it to Bedrock, and processes the streaming response. 
673+ // Returns a channel for streaming BifrostResponse objects or an error if the request fails. 
674+ func  (provider  * BedrockProvider ) ResponsesStream (ctx  context.Context , postHookRunner  schemas.PostHookRunner , key  schemas.Key , request  * schemas.BifrostResponsesRequest ) (chan  * schemas.BifrostStream , * schemas.BifrostError ) {
675+ 	if  err  :=  checkOperationAllowed (schemas .Bedrock , provider .customProviderConfig , schemas .ResponsesStreamRequest ); err  !=  nil  {
676+ 		return  nil , err 
677+ 	}
678+ 
679+ 	providerName  :=  provider .GetProviderKey ()
680+ 
681+ 	reqBody , err  :=  bedrock .ToBedrockResponsesRequest (request )
682+ 	if  err  !=  nil  {
683+ 		return  nil , newBifrostOperationError ("failed to convert request" , err , providerName )
684+ 	}
685+ 
686+ 	resp , bifrostErr  :=  provider .makeStreamingRequest (ctx , reqBody , key , request .Model )
687+ 	if  bifrostErr  !=  nil  {
688+ 		return  nil , bifrostErr 
689+ 	}
690+ 
691+ 	// Create response channel 
692+ 	responseChan  :=  make (chan  * schemas.BifrostStream , schemas .DefaultStreamBufferSize )
693+ 
694+ 	// Start streaming in a goroutine 
695+ 	go  func () {
696+ 		defer  close (responseChan )
697+ 		defer  resp .Body .Close ()
698+ 
699+ 		// Process AWS Event Stream format 
700+ 		var  usage  * schemas.LLMUsage 
701+ 		chunkIndex  :=  0 
702+ 
703+ 		// Process AWS Event Stream format using proper decoder 
704+ 		startTime  :=  time .Now ()
705+ 		lastChunkTime  :=  startTime 
706+ 		decoder  :=  eventstream .NewDecoder ()
707+ 		payloadBuf  :=  make ([]byte , 0 , 1024 * 1024 ) // 1MB payload buffer 
708+ 
709+ 		for  {
710+ 			// Decode a single EventStream message 
711+ 			message , err  :=  decoder .Decode (resp .Body , payloadBuf )
712+ 			if  err  !=  nil  {
713+ 				if  err  ==  io .EOF  {
714+ 					// End of stream - this is normal 
715+ 					break 
716+ 				}
717+ 				provider .logger .Warn (fmt .Sprintf ("Error decoding %s EventStream message: %v" , providerName , err ))
718+ 				processAndSendError (ctx , postHookRunner , err , responseChan , schemas .ResponsesStreamRequest , providerName , request .Model , provider .logger )
719+ 				return 
720+ 			}
721+ 
722+ 			// Process the decoded message payload (contains JSON for normal events) 
723+ 			if  len (message .Payload ) >  0  {
724+ 				if  msgTypeHeader  :=  message .Headers .Get (":message-type" ); msgTypeHeader  !=  nil  {
725+ 					if  msgType  :=  msgTypeHeader .String (); msgType  !=  "event"  {
726+ 						excType  :=  msgType 
727+ 						if  excHeader  :=  message .Headers .Get (":exception-type" ); excHeader  !=  nil  {
728+ 							if  v  :=  excHeader .String (); v  !=  ""  {
729+ 								excType  =  v 
730+ 							}
731+ 						}
732+ 						errMsg  :=  string (message .Payload )
733+ 						err  :=  fmt .Errorf ("%s stream %s: %s" , providerName , excType , errMsg )
734+ 						processAndSendError (ctx , postHookRunner , err , responseChan , schemas .ChatCompletionStreamRequest , providerName , request .Model , provider .logger )
735+ 						return 
736+ 					}
737+ 				}
738+ 
739+ 				// Parse the JSON event into our typed structure 
740+ 				var  streamEvent  bedrock.BedrockStreamEvent 
741+ 				if  err  :=  sonic .Unmarshal (message .Payload , & streamEvent ); err  !=  nil  {
742+ 					provider .logger .Debug (fmt .Sprintf ("Failed to parse JSON from event buffer: %v, data: %s" , err , string (message .Payload )))
743+ 					return 
744+ 				}
745+ 
746+ 				if  chunkIndex  ==  0  {
747+ 					sendCreatedEventResponsesChunk (ctx , postHookRunner , provider .GetProviderKey (), request .Model , startTime , responseChan , provider .logger )
748+ 					sendInProgressEventResponsesChunk (ctx , postHookRunner , provider .GetProviderKey (), request .Model , startTime , responseChan , provider .logger )
749+ 					chunkIndex  =  2 
750+ 				}
751+ 
752+ 				if  streamEvent .Usage  !=  nil  {
753+ 					usage  =  & schemas.LLMUsage {
754+ 						ResponsesExtendedResponseUsage : & schemas.ResponsesExtendedResponseUsage {
755+ 							InputTokens :  streamEvent .Usage .InputTokens ,
756+ 							OutputTokens : streamEvent .Usage .OutputTokens ,
757+ 						},
758+ 						TotalTokens : streamEvent .Usage .TotalTokens ,
759+ 					}
760+ 				}
761+ 
762+ 				response , bifrostErr , _  :=  streamEvent .ToBifrostResponsesStream (chunkIndex )
763+ 				if  response  !=  nil  {
764+ 
765+ 					response .ExtraFields  =  schemas.BifrostResponseExtraFields {
766+ 						RequestType :    schemas .ResponsesStreamRequest ,
767+ 						Provider :       providerName ,
768+ 						ModelRequested : request .Model ,
769+ 						ChunkIndex :     chunkIndex ,
770+ 						Latency :        time .Since (lastChunkTime ).Milliseconds (),
771+ 					}
772+ 					chunkIndex ++ 
773+ 					lastChunkTime  =  time .Now ()
774+ 
775+ 					if  provider .sendBackRawResponse  {
776+ 						response .ExtraFields .RawResponse  =  string (message .Payload )
777+ 					}
778+ 
779+ 					processAndSendResponse (ctx , postHookRunner , response , responseChan , provider .logger )
780+ 				}
781+ 				if  bifrostErr  !=  nil  {
782+ 					bifrostErr .ExtraFields  =  schemas.BifrostErrorExtraFields {
783+ 						RequestType :    schemas .ResponsesStreamRequest ,
784+ 						Provider :       providerName ,
785+ 						ModelRequested : request .Model ,
786+ 					}
787+ 					processAndSendBifrostError (ctx , postHookRunner , bifrostErr , responseChan , provider .logger )
788+ 					return 
789+ 				}
790+ 			}
791+ 		}
792+ 
793+ 		// Send final response 
794+ 		response  :=  & schemas.BifrostResponse {
795+ 			ResponsesStreamResponse : & schemas.ResponsesStreamResponse {
796+ 				Type :           schemas .ResponsesStreamResponseTypeCompleted ,
797+ 				SequenceNumber : chunkIndex  +  1 ,
798+ 				Response : & schemas.ResponsesStreamResponseStruct {
799+ 					Usage : & schemas.ResponsesResponseUsage {
800+ 						ResponsesExtendedResponseUsage : & schemas.ResponsesExtendedResponseUsage {
801+ 							InputTokens :  usage .InputTokens ,
802+ 							OutputTokens : usage .OutputTokens ,
803+ 						},
804+ 						TotalTokens : usage .TotalTokens ,
805+ 					},
806+ 				},
807+ 			},
808+ 			ExtraFields : schemas.BifrostResponseExtraFields {
809+ 				RequestType :    schemas .ResponsesStreamRequest ,
810+ 				Provider :       providerName ,
811+ 				ModelRequested : request .Model ,
812+ 				ChunkIndex :     chunkIndex  +  1 ,
813+ 				Latency :        time .Since (startTime ).Milliseconds (),
814+ 			},
815+ 		}
816+ 		handleStreamEndWithSuccess (ctx , response , postHookRunner , responseChan , provider .logger )
817+ 	}()
818+ 
819+ 	return  responseChan , nil 
820+ }
821+ 
657822// Embedding generates embeddings for the given input text(s) using Amazon Bedrock. 
658823// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
659824func  (provider  * BedrockProvider ) Embedding (ctx  context.Context , key  schemas.Key , request  * schemas.BifrostEmbeddingRequest ) (* schemas.BifrostResponse , * schemas.BifrostError ) {
@@ -769,7 +934,3 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key
769934
770935	return  path 
771936}
772- 
773- func  (provider  * BedrockProvider ) ResponsesStream (ctx  context.Context , postHookRunner  schemas.PostHookRunner , key  schemas.Key , request  * schemas.BifrostResponsesRequest ) (chan  * schemas.BifrostStream , * schemas.BifrostError ) {
774- 	return  nil , newUnsupportedOperationError ("responses stream" , "bedrock" )
775- }
0 commit comments