15
15
)
16
16
17
17
import boto3
18
- from fastapi import HTTPException
19
18
from llmstudio_core .exceptions import ProviderError
20
19
from llmstudio_core .providers .provider import ChatRequest , ProviderCore , provider
21
20
30
29
)
31
30
from pydantic import ValidationError
32
31
32
+ SERVICE = "bedrock-runtime"
33
+
33
34
34
35
@provider
35
- class BedrockAntropicProvider (ProviderCore ):
36
+ class BedrockAnthropicProvider (ProviderCore ):
36
37
def __init__ (self , config , ** kwargs ):
37
38
super ().__init__ (config , ** kwargs )
38
- self .access_key = (
39
- self .access_key if self .access_key else os .getenv ("BEDROCK_ACCESS_KEY" )
40
- )
41
- self .secret_key = (
42
- self .secret_key if self .secret_key else os .getenv ("BEDROCK_SECRET_KEY" )
39
+ self ._client = boto3 .client (
40
+ SERVICE ,
41
+ region_name = self .region if self .region else os .getenv ("BEDROCK_REGION" ),
42
+ aws_access_key_id = self .access_key
43
+ if self .access_key
44
+ else os .getenv ("BEDROCK_ACCESS_KEY" ),
45
+ aws_secret_access_key = self .secret_key
46
+ if self .secret_key
47
+ else os .getenv ("BEDROCK_SECRET_KEY" ),
43
48
)
44
- self .region = self .region if self .region else os .getenv ("BEDROCK_REGION" )
45
49
46
50
@staticmethod
47
51
def _provider_config_name ():
@@ -57,26 +61,6 @@ async def agenerate_client(self, request: ChatRequest) -> Coroutine[Any, Any, An
57
61
def generate_client (self , request : ChatRequest ) -> Coroutine [Any , Any , Generator ]:
58
62
"""Generate an AWS Bedrock client"""
59
63
try :
60
-
61
- service = "bedrock-runtime"
62
-
63
- if (
64
- self .access_key is None
65
- or self .secret_key is None
66
- or self .region is None
67
- ):
68
- raise HTTPException (
69
- status_code = 400 ,
70
- detail = "AWS credentials were not given or not set in environment variables." ,
71
- )
72
-
73
- client = boto3 .client (
74
- service ,
75
- region_name = self .region ,
76
- aws_access_key_id = self .access_key ,
77
- aws_secret_access_key = self .secret_key ,
78
- )
79
-
80
64
messages , system_prompt = self ._process_messages (request .chat_input )
81
65
tools = self ._process_tools (request .parameters )
82
66
@@ -95,7 +79,7 @@ def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator
95
79
if tools :
96
80
client_params ["toolConfig" ] = tools
97
81
98
- return client .converse_stream (** client_params )
82
+ return self . _client .converse_stream (** client_params )
99
83
except Exception as e :
100
84
raise ProviderError (str (e ))
101
85
@@ -249,18 +233,7 @@ def parse_response(self, response: AsyncGenerator[Any, None], **kwargs) -> Any:
249
233
def _process_messages (
250
234
chat_input : Union [str , List [Dict [str , str ]]]
251
235
) -> List [Dict [str , Union [List [Dict [str , str ]], str ]]]:
252
- """
253
- Generate input text for the Bedrock API based on the provided chat input.
254
-
255
- Args:
256
- chat_input (Union[str, List[Dict[str, str]]]): The input text or a list of message dictionaries.
257
-
258
- Returns:
259
- List[Dict[str, Union[List[Dict[str, str]], str]]]: A list of formatted messages for the Bedrock API.
260
236
261
- Raises:
262
- HTTPException: If the input is invalid.
263
- """
264
237
if isinstance (chat_input , str ):
265
238
return [
266
239
{
0 commit comments