Skip to content

Commit fcfe33e

Browse files
committed
[feat] bedrock mapper
1 parent 3d65f0a commit fcfe33e

File tree

2 files changed

+398
-348
lines changed

2 files changed

+398
-348
lines changed
Lines changed: 19 additions & 348 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,23 @@
1-
import asyncio
2-
import json
3-
import os
4-
import time
5-
import uuid
6-
from typing import (
7-
Any,
8-
AsyncGenerator,
9-
Coroutine,
10-
Dict,
11-
Generator,
12-
List,
13-
Optional,
14-
Union,
15-
)
1+
from typing import Any, AsyncGenerator, Coroutine, Generator
162

17-
import boto3
18-
from fastapi import HTTPException
19-
from llmstudio_core.exceptions import ProviderError
3+
from llmstudio_core.providers.bedrock_providers.antropic import BedrockAntropicProvider
204
from llmstudio_core.providers.provider import ChatRequest, ProviderCore, provider
21-
from llmstudio_core.utils import OpenAITool, OpenAIToolFunction
22-
from openai.types.chat import ChatCompletionChunk
23-
from openai.types.chat.chat_completion_chunk import (
24-
Choice,
25-
ChoiceDelta,
26-
ChoiceDeltaToolCall,
27-
ChoiceDeltaToolCallFunction,
28-
)
29-
from pydantic import ValidationError
5+
6+
SUPORTED_PROVIDERS = ["antropic"]
307

318

329
@provider
3310
class BedrockProvider(ProviderCore):
3411
def __init__(self, config, **kwargs):
3512
super().__init__(config, **kwargs)
36-
self.access_key = (
37-
self.access_key if self.access_key else os.getenv("BEDROCK_ACCESS_KEY")
38-
)
39-
self.secret_key = (
40-
self.secret_key if self.secret_key else os.getenv("BEDROCK_SECRET_KEY")
41-
)
42-
self.region = self.region if self.region else os.getenv("BEDROCK_REGION")
13+
self.kwargs = kwargs
14+
self.selected_model = None
15+
16+
def _get_provider(self, model):
17+
if "anthropic." in model:
18+
return BedrockAntropicProvider(config=self.config, **self.kwargs)
19+
20+
raise ValueError(f" provider is not yet supported.")
4321

4422
@staticmethod
4523
def _provider_config_name():
@@ -49,325 +27,18 @@ def validate_request(self, request: ChatRequest):
4927
return ChatRequest(**request)
5028

5129
async def agenerate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Any]:
52-
"""Generate an AWS Bedrock client"""
53-
return await asyncio.to_thread(self.generate_client, request)
30+
self.selected_model = self._get_provider(request.model)
31+
return self.selected_model.agenerate_client()
5432

5533
def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator]:
56-
"""Generate an AWS Bedrock client"""
57-
try:
58-
59-
service = "bedrock-runtime"
60-
61-
if (
62-
self.access_key is None
63-
or self.secret_key is None
64-
or self.region is None
65-
):
66-
raise HTTPException(
67-
status_code=400,
68-
detail="AWS credentials were not given or not set in environment variables.",
69-
)
70-
71-
client = boto3.client(
72-
service,
73-
region_name=self.region,
74-
aws_access_key_id=self.access_key,
75-
aws_secret_access_key=self.secret_key,
76-
)
77-
78-
messages, system_prompt = self._process_messages(request.chat_input)
79-
tools = self._process_tools(request.parameters)
80-
81-
system_prompt = (
82-
request.parameters.get("system")
83-
if request.parameters.get("system")
84-
else system_prompt
85-
)
86-
87-
client_params = {
88-
"modelId": request.model,
89-
"messages": messages,
90-
"inferenceConfig": self._process_parameters(request.parameters),
91-
"system": system_prompt,
92-
}
93-
if tools:
94-
client_params["toolConfig"] = tools
95-
96-
return client.converse_stream(**client_params)
97-
except Exception as e:
98-
raise ProviderError(str(e))
34+
self.selected_model = self._get_provider(request.model)
35+
client = self.selected_model.generate_client(request=request)
36+
return client
9937

10038
async def aparse_response(
10139
self, response: Any, **kwargs
10240
) -> AsyncGenerator[str, None]:
103-
iterator = await asyncio.to_thread(
104-
self.parse_response, response=response, **kwargs
105-
)
106-
for item in iterator:
107-
yield item
41+
return self.selected_model.aparse_response(response=response, **kwargs)
10842

10943
def parse_response(self, response: AsyncGenerator[Any, None], **kwargs) -> Any:
110-
tool_name = None
111-
tool_arguments = ""
112-
tool_id = None
113-
114-
for chunk in response["stream"]:
115-
if chunk.get("messageStart"):
116-
first_chunk = ChatCompletionChunk(
117-
id=str(uuid.uuid4()),
118-
choices=[
119-
Choice(
120-
delta=ChoiceDelta(
121-
content=None,
122-
function_call=None,
123-
role="assistant",
124-
tool_calls=None,
125-
),
126-
index=0,
127-
)
128-
],
129-
created=int(time.time()),
130-
model=kwargs.get("request").model,
131-
object="chat.completion.chunk",
132-
usage=None,
133-
)
134-
yield first_chunk.model_dump()
135-
136-
elif chunk.get("contentBlockStart"):
137-
if chunk["contentBlockStart"]["start"].get("toolUse"):
138-
tool_name = chunk["contentBlockStart"]["start"]["toolUse"]["name"]
139-
tool_arguments = ""
140-
tool_id = chunk["contentBlockStart"]["start"]["toolUse"][
141-
"toolUseId"
142-
]
143-
144-
elif chunk.get("contentBlockDelta"):
145-
delta = chunk["contentBlockDelta"]["delta"]
146-
if delta.get("text"):
147-
# Regular content, yield it
148-
text = delta["text"]
149-
chunk = ChatCompletionChunk(
150-
id=str(uuid.uuid4()),
151-
choices=[
152-
Choice(
153-
delta=ChoiceDelta(content=text),
154-
finish_reason=None,
155-
index=0,
156-
)
157-
],
158-
created=int(time.time()),
159-
model=kwargs.get("request").model,
160-
object="chat.completion.chunk",
161-
)
162-
yield chunk.model_dump()
163-
164-
elif delta.get("toolUse"):
165-
partial_json = delta["toolUse"]["input"]
166-
tool_arguments += partial_json
167-
168-
elif chunk.get("contentBlockStop") and tool_id:
169-
name_chunk = ChatCompletionChunk(
170-
id=str(uuid.uuid4()),
171-
choices=[
172-
Choice(
173-
delta=ChoiceDelta(
174-
role="assistant",
175-
tool_calls=[
176-
ChoiceDeltaToolCall(
177-
index=chunk["contentBlockStop"][
178-
"contentBlockIndex"
179-
],
180-
id=tool_id,
181-
function=ChoiceDeltaToolCallFunction(
182-
name=tool_name,
183-
arguments="",
184-
type="function",
185-
),
186-
)
187-
],
188-
),
189-
finish_reason=None,
190-
index=chunk["contentBlockStop"]["contentBlockIndex"],
191-
)
192-
],
193-
created=int(time.time()),
194-
model=kwargs.get("request").model,
195-
object="chat.completion.chunk",
196-
)
197-
yield name_chunk.model_dump()
198-
199-
args_chunk = ChatCompletionChunk(
200-
id=tool_id,
201-
choices=[
202-
Choice(
203-
delta=ChoiceDelta(
204-
tool_calls=[
205-
ChoiceDeltaToolCall(
206-
index=chunk["contentBlockStop"][
207-
"contentBlockIndex"
208-
],
209-
function=ChoiceDeltaToolCallFunction(
210-
arguments=tool_arguments,
211-
),
212-
)
213-
],
214-
),
215-
finish_reason=None,
216-
index=chunk["contentBlockStop"]["contentBlockIndex"],
217-
)
218-
],
219-
created=int(time.time()),
220-
model=kwargs.get("request").model,
221-
object="chat.completion.chunk",
222-
)
223-
yield args_chunk.model_dump()
224-
225-
elif chunk.get("messageStop"):
226-
stop_reason = chunk["messageStop"].get("stopReason")
227-
final_chunk = ChatCompletionChunk(
228-
id=str(uuid.uuid4()),
229-
choices=[
230-
Choice(
231-
delta=ChoiceDelta(),
232-
finish_reason="tool_calls"
233-
if stop_reason == "tool_use"
234-
else "stop",
235-
index=0,
236-
)
237-
],
238-
created=int(time.time()),
239-
model=kwargs.get("request").model,
240-
object="chat.completion.chunk",
241-
)
242-
yield final_chunk.model_dump()
243-
244-
@staticmethod
245-
def _process_messages(
246-
chat_input: Union[str, List[Dict[str, str]]]
247-
) -> List[Dict[str, Union[List[Dict[str, str]], str]]]:
248-
"""
249-
Generate input text for the Bedrock API based on the provided chat input.
250-
251-
Args:
252-
chat_input (Union[str, List[Dict[str, str]]]): The input text or a list of message dictionaries.
253-
254-
Returns:
255-
List[Dict[str, Union[List[Dict[str, str]], str]]]: A list of formatted messages for the Bedrock API.
256-
257-
Raises:
258-
HTTPException: If the input is invalid.
259-
"""
260-
if isinstance(chat_input, str):
261-
return [
262-
{
263-
"role": "user",
264-
"content": [{"text": chat_input}],
265-
}
266-
]
267-
268-
elif isinstance(chat_input, list):
269-
messages = []
270-
tool_result = None
271-
system_prompt = []
272-
for message in chat_input:
273-
if message.get("role") in ["assistant", "user"]:
274-
if message.get("tool_calls"):
275-
tool_use = {"role": "assistant", "content": []}
276-
for tool in message.get("tool_calls"):
277-
tool_use["content"].append(
278-
{
279-
"toolUse": {
280-
"toolUseId": tool["id"],
281-
"name": tool["function"]["name"],
282-
"input": json.loads(
283-
tool["function"]["arguments"]
284-
),
285-
}
286-
}
287-
)
288-
messages.append(tool_use)
289-
else:
290-
messages.append(
291-
{
292-
"role": message.get("role"),
293-
"content": [{"text": message.get("content")}],
294-
}
295-
)
296-
if message.get("role") in ["tool"]:
297-
if not tool_result:
298-
tool_result = {"role": "user", "content": []}
299-
tool_result["content"].append(
300-
{
301-
"toolResult": {
302-
"toolUseId": message["tool_call_id"],
303-
"content": [{"json": {"text": message["content"]}}],
304-
}
305-
}
306-
)
307-
if message.get("role") in ["system"]:
308-
system_prompt = [{"text": message.get("content")}]
309-
310-
if tool_result:
311-
messages.append(tool_result)
312-
313-
return messages, system_prompt
314-
315-
@staticmethod
316-
def _process_tools(parameters: dict) -> Optional[Dict]:
317-
if parameters.get("tools") is None and parameters.get("functions") is None:
318-
return None
319-
320-
try:
321-
if parameters.get("tools"):
322-
parsed_tools = (
323-
[OpenAITool(**tool) for tool in parameters.get("tools")]
324-
if isinstance(parameters.get("tools"), list)
325-
else [OpenAITool(**parameters.get("tools"))]
326-
)
327-
if parameters.get("functions"):
328-
parsed_tools = (
329-
[OpenAIToolFunction(**tool) for tool in parameters.get("functions")]
330-
if isinstance(parameters.get("functions"), list)
331-
else [OpenAIToolFunction(**parameters.get("functions"))]
332-
)
333-
tool_configurations = []
334-
for tool in parsed_tools:
335-
tool_config = {
336-
"toolSpec": {
337-
"name": tool.function.name
338-
if parameters.get("tools")
339-
else tool.name,
340-
"description": tool.function.description
341-
if parameters.get("tools")
342-
else tool.description,
343-
"inputSchema": {
344-
"json": {
345-
"type": tool.function.parameters.type
346-
if parameters.get("tools")
347-
else tool.parameters.type,
348-
"properties": tool.function.parameters.properties
349-
if parameters.get("tools")
350-
else tool.parameters.properties,
351-
"required": tool.function.parameters.required
352-
if parameters.get("tools")
353-
else tool.parameters.required,
354-
}
355-
},
356-
}
357-
}
358-
tool_configurations.append(tool_config)
359-
return {"tools": tool_configurations}
360-
361-
except ValidationError:
362-
return (
363-
parameters.get("tools")
364-
if parameters.get("tools")
365-
else parameters.get("functions")
366-
)
367-
368-
@staticmethod
369-
def _process_parameters(parameters: dict) -> dict:
370-
remove_keys = ["system", "stop", "tools"]
371-
for key in remove_keys:
372-
parameters.pop(key, None)
373-
return parameters
44+
return self.selected_model.parse_response(response=response, **kwargs)

0 commit comments

Comments
 (0)