1
- import asyncio
2
- import json
3
- import os
4
- import time
5
- import uuid
6
- from typing import (
7
- Any ,
8
- AsyncGenerator ,
9
- Coroutine ,
10
- Dict ,
11
- Generator ,
12
- List ,
13
- Optional ,
14
- Union ,
15
- )
1
+ from typing import Any , AsyncGenerator , Coroutine , Generator
16
2
17
- import boto3
18
- from fastapi import HTTPException
19
- from llmstudio_core .exceptions import ProviderError
3
+ from llmstudio_core .providers .bedrock_providers .antropic import BedrockAntropicProvider
20
4
from llmstudio_core .providers .provider import ChatRequest , ProviderCore , provider
21
- from llmstudio_core .utils import OpenAITool , OpenAIToolFunction
22
- from openai .types .chat import ChatCompletionChunk
23
- from openai .types .chat .chat_completion_chunk import (
24
- Choice ,
25
- ChoiceDelta ,
26
- ChoiceDeltaToolCall ,
27
- ChoiceDeltaToolCallFunction ,
28
- )
29
- from pydantic import ValidationError
5
+
6
# Registry of provider families this Bedrock wrapper can dispatch to.
# NOTE(review): "SUPORTED" is a misspelling of "SUPPORTED", and "antropic"
# mirrors the (also misspelled) module path `bedrock_providers.antropic`;
# renaming either may break external references — confirm before fixing.
SUPORTED_PROVIDERS = ["antropic"]
30
7
31
8
32
9
@provider
33
10
class BedrockProvider (ProviderCore ):
34
11
def __init__ (self , config , ** kwargs ):
35
12
super ().__init__ (config , ** kwargs )
36
- self .access_key = (
37
- self .access_key if self .access_key else os .getenv ("BEDROCK_ACCESS_KEY" )
38
- )
39
- self .secret_key = (
40
- self .secret_key if self .secret_key else os .getenv ("BEDROCK_SECRET_KEY" )
41
- )
42
- self .region = self .region if self .region else os .getenv ("BEDROCK_REGION" )
13
+ self .kwargs = kwargs
14
+ self .selected_model = None
15
+
16
+ def _get_provider (self , model ):
17
+ if "anthropic." in model :
18
+ return BedrockAntropicProvider (config = self .config , ** self .kwargs )
19
+
20
+ raise ValueError (f" provider is not yet supported." )
43
21
44
22
@staticmethod
45
23
def _provider_config_name ():
@@ -49,325 +27,18 @@ def validate_request(self, request: ChatRequest):
49
27
return ChatRequest (** request )
50
28
51
29
    async def agenerate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Any]:
        """Async variant: pick the sub-provider for the request's model and
        delegate client creation to it.

        Side effect: caches the chosen sub-provider on ``self.selected_model``
        so later parse calls route to the same implementation.
        """
        self.selected_model = self._get_provider(request.model)
        # NOTE(review): unlike generate_client, no `request` is forwarded here,
        # and the result is returned without `await` — if the sub-provider's
        # agenerate_client is a coroutine function, callers receive an
        # unawaited coroutine. Confirm against BedrockAntropicProvider.
        return self.selected_model.agenerate_client()
54
32
55
33
def generate_client (self , request : ChatRequest ) -> Coroutine [Any , Any , Generator ]:
56
- """Generate an AWS Bedrock client"""
57
- try :
58
-
59
- service = "bedrock-runtime"
60
-
61
- if (
62
- self .access_key is None
63
- or self .secret_key is None
64
- or self .region is None
65
- ):
66
- raise HTTPException (
67
- status_code = 400 ,
68
- detail = "AWS credentials were not given or not set in environment variables." ,
69
- )
70
-
71
- client = boto3 .client (
72
- service ,
73
- region_name = self .region ,
74
- aws_access_key_id = self .access_key ,
75
- aws_secret_access_key = self .secret_key ,
76
- )
77
-
78
- messages , system_prompt = self ._process_messages (request .chat_input )
79
- tools = self ._process_tools (request .parameters )
80
-
81
- system_prompt = (
82
- request .parameters .get ("system" )
83
- if request .parameters .get ("system" )
84
- else system_prompt
85
- )
86
-
87
- client_params = {
88
- "modelId" : request .model ,
89
- "messages" : messages ,
90
- "inferenceConfig" : self ._process_parameters (request .parameters ),
91
- "system" : system_prompt ,
92
- }
93
- if tools :
94
- client_params ["toolConfig" ] = tools
95
-
96
- return client .converse_stream (** client_params )
97
- except Exception as e :
98
- raise ProviderError (str (e ))
34
+ self .selected_model = self ._get_provider (request .model )
35
+ client = self .selected_model .generate_client (request = request )
36
+ return client
99
37
100
38
    async def aparse_response(
        self, response: Any, **kwargs
    ) -> AsyncGenerator[str, None]:
        """Delegate async response parsing to the previously selected sub-provider.

        Requires that ``generate_client``/``agenerate_client`` ran first and set
        ``self.selected_model``; otherwise this raises AttributeError on None.
        """
        # NOTE(review): this `async def` *returns* a value instead of yielding,
        # so it is a coroutine, not an async generator as the annotation claims;
        # callers likely need `await` plus iteration. Confirm against the
        # sub-provider's aparse_response before changing.
        return self.selected_model.aparse_response(response=response, **kwargs)
108
42
109
43
def parse_response (self , response : AsyncGenerator [Any , None ], ** kwargs ) -> Any :
110
- tool_name = None
111
- tool_arguments = ""
112
- tool_id = None
113
-
114
- for chunk in response ["stream" ]:
115
- if chunk .get ("messageStart" ):
116
- first_chunk = ChatCompletionChunk (
117
- id = str (uuid .uuid4 ()),
118
- choices = [
119
- Choice (
120
- delta = ChoiceDelta (
121
- content = None ,
122
- function_call = None ,
123
- role = "assistant" ,
124
- tool_calls = None ,
125
- ),
126
- index = 0 ,
127
- )
128
- ],
129
- created = int (time .time ()),
130
- model = kwargs .get ("request" ).model ,
131
- object = "chat.completion.chunk" ,
132
- usage = None ,
133
- )
134
- yield first_chunk .model_dump ()
135
-
136
- elif chunk .get ("contentBlockStart" ):
137
- if chunk ["contentBlockStart" ]["start" ].get ("toolUse" ):
138
- tool_name = chunk ["contentBlockStart" ]["start" ]["toolUse" ]["name" ]
139
- tool_arguments = ""
140
- tool_id = chunk ["contentBlockStart" ]["start" ]["toolUse" ][
141
- "toolUseId"
142
- ]
143
-
144
- elif chunk .get ("contentBlockDelta" ):
145
- delta = chunk ["contentBlockDelta" ]["delta" ]
146
- if delta .get ("text" ):
147
- # Regular content, yield it
148
- text = delta ["text" ]
149
- chunk = ChatCompletionChunk (
150
- id = str (uuid .uuid4 ()),
151
- choices = [
152
- Choice (
153
- delta = ChoiceDelta (content = text ),
154
- finish_reason = None ,
155
- index = 0 ,
156
- )
157
- ],
158
- created = int (time .time ()),
159
- model = kwargs .get ("request" ).model ,
160
- object = "chat.completion.chunk" ,
161
- )
162
- yield chunk .model_dump ()
163
-
164
- elif delta .get ("toolUse" ):
165
- partial_json = delta ["toolUse" ]["input" ]
166
- tool_arguments += partial_json
167
-
168
- elif chunk .get ("contentBlockStop" ) and tool_id :
169
- name_chunk = ChatCompletionChunk (
170
- id = str (uuid .uuid4 ()),
171
- choices = [
172
- Choice (
173
- delta = ChoiceDelta (
174
- role = "assistant" ,
175
- tool_calls = [
176
- ChoiceDeltaToolCall (
177
- index = chunk ["contentBlockStop" ][
178
- "contentBlockIndex"
179
- ],
180
- id = tool_id ,
181
- function = ChoiceDeltaToolCallFunction (
182
- name = tool_name ,
183
- arguments = "" ,
184
- type = "function" ,
185
- ),
186
- )
187
- ],
188
- ),
189
- finish_reason = None ,
190
- index = chunk ["contentBlockStop" ]["contentBlockIndex" ],
191
- )
192
- ],
193
- created = int (time .time ()),
194
- model = kwargs .get ("request" ).model ,
195
- object = "chat.completion.chunk" ,
196
- )
197
- yield name_chunk .model_dump ()
198
-
199
- args_chunk = ChatCompletionChunk (
200
- id = tool_id ,
201
- choices = [
202
- Choice (
203
- delta = ChoiceDelta (
204
- tool_calls = [
205
- ChoiceDeltaToolCall (
206
- index = chunk ["contentBlockStop" ][
207
- "contentBlockIndex"
208
- ],
209
- function = ChoiceDeltaToolCallFunction (
210
- arguments = tool_arguments ,
211
- ),
212
- )
213
- ],
214
- ),
215
- finish_reason = None ,
216
- index = chunk ["contentBlockStop" ]["contentBlockIndex" ],
217
- )
218
- ],
219
- created = int (time .time ()),
220
- model = kwargs .get ("request" ).model ,
221
- object = "chat.completion.chunk" ,
222
- )
223
- yield args_chunk .model_dump ()
224
-
225
- elif chunk .get ("messageStop" ):
226
- stop_reason = chunk ["messageStop" ].get ("stopReason" )
227
- final_chunk = ChatCompletionChunk (
228
- id = str (uuid .uuid4 ()),
229
- choices = [
230
- Choice (
231
- delta = ChoiceDelta (),
232
- finish_reason = "tool_calls"
233
- if stop_reason == "tool_use"
234
- else "stop" ,
235
- index = 0 ,
236
- )
237
- ],
238
- created = int (time .time ()),
239
- model = kwargs .get ("request" ).model ,
240
- object = "chat.completion.chunk" ,
241
- )
242
- yield final_chunk .model_dump ()
243
-
244
- @staticmethod
245
- def _process_messages (
246
- chat_input : Union [str , List [Dict [str , str ]]]
247
- ) -> List [Dict [str , Union [List [Dict [str , str ]], str ]]]:
248
- """
249
- Generate input text for the Bedrock API based on the provided chat input.
250
-
251
- Args:
252
- chat_input (Union[str, List[Dict[str, str]]]): The input text or a list of message dictionaries.
253
-
254
- Returns:
255
- List[Dict[str, Union[List[Dict[str, str]], str]]]: A list of formatted messages for the Bedrock API.
256
-
257
- Raises:
258
- HTTPException: If the input is invalid.
259
- """
260
- if isinstance (chat_input , str ):
261
- return [
262
- {
263
- "role" : "user" ,
264
- "content" : [{"text" : chat_input }],
265
- }
266
- ]
267
-
268
- elif isinstance (chat_input , list ):
269
- messages = []
270
- tool_result = None
271
- system_prompt = []
272
- for message in chat_input :
273
- if message .get ("role" ) in ["assistant" , "user" ]:
274
- if message .get ("tool_calls" ):
275
- tool_use = {"role" : "assistant" , "content" : []}
276
- for tool in message .get ("tool_calls" ):
277
- tool_use ["content" ].append (
278
- {
279
- "toolUse" : {
280
- "toolUseId" : tool ["id" ],
281
- "name" : tool ["function" ]["name" ],
282
- "input" : json .loads (
283
- tool ["function" ]["arguments" ]
284
- ),
285
- }
286
- }
287
- )
288
- messages .append (tool_use )
289
- else :
290
- messages .append (
291
- {
292
- "role" : message .get ("role" ),
293
- "content" : [{"text" : message .get ("content" )}],
294
- }
295
- )
296
- if message .get ("role" ) in ["tool" ]:
297
- if not tool_result :
298
- tool_result = {"role" : "user" , "content" : []}
299
- tool_result ["content" ].append (
300
- {
301
- "toolResult" : {
302
- "toolUseId" : message ["tool_call_id" ],
303
- "content" : [{"json" : {"text" : message ["content" ]}}],
304
- }
305
- }
306
- )
307
- if message .get ("role" ) in ["system" ]:
308
- system_prompt = [{"text" : message .get ("content" )}]
309
-
310
- if tool_result :
311
- messages .append (tool_result )
312
-
313
- return messages , system_prompt
314
-
315
- @staticmethod
316
- def _process_tools (parameters : dict ) -> Optional [Dict ]:
317
- if parameters .get ("tools" ) is None and parameters .get ("functions" ) is None :
318
- return None
319
-
320
- try :
321
- if parameters .get ("tools" ):
322
- parsed_tools = (
323
- [OpenAITool (** tool ) for tool in parameters .get ("tools" )]
324
- if isinstance (parameters .get ("tools" ), list )
325
- else [OpenAITool (** parameters .get ("tools" ))]
326
- )
327
- if parameters .get ("functions" ):
328
- parsed_tools = (
329
- [OpenAIToolFunction (** tool ) for tool in parameters .get ("functions" )]
330
- if isinstance (parameters .get ("functions" ), list )
331
- else [OpenAIToolFunction (** parameters .get ("functions" ))]
332
- )
333
- tool_configurations = []
334
- for tool in parsed_tools :
335
- tool_config = {
336
- "toolSpec" : {
337
- "name" : tool .function .name
338
- if parameters .get ("tools" )
339
- else tool .name ,
340
- "description" : tool .function .description
341
- if parameters .get ("tools" )
342
- else tool .description ,
343
- "inputSchema" : {
344
- "json" : {
345
- "type" : tool .function .parameters .type
346
- if parameters .get ("tools" )
347
- else tool .parameters .type ,
348
- "properties" : tool .function .parameters .properties
349
- if parameters .get ("tools" )
350
- else tool .parameters .properties ,
351
- "required" : tool .function .parameters .required
352
- if parameters .get ("tools" )
353
- else tool .parameters .required ,
354
- }
355
- },
356
- }
357
- }
358
- tool_configurations .append (tool_config )
359
- return {"tools" : tool_configurations }
360
-
361
- except ValidationError :
362
- return (
363
- parameters .get ("tools" )
364
- if parameters .get ("tools" )
365
- else parameters .get ("functions" )
366
- )
367
-
368
- @staticmethod
369
- def _process_parameters (parameters : dict ) -> dict :
370
- remove_keys = ["system" , "stop" , "tools" ]
371
- for key in remove_keys :
372
- parameters .pop (key , None )
373
- return parameters
44
+ return self .selected_model .parse_response (response = response , ** kwargs )
0 commit comments