Commit 443bb45

Added Mistral models to bedrock
1 parent 83e8f56 · commit 443bb45

File tree

1 file changed: +90 -0 lines changed

src/fmcore/algorithm/bedrock.py

Lines changed: 90 additions & 0 deletions
@@ -313,6 +313,80 @@ def call_claude_v3_messages_api(
     response_body: Dict = json.loads(response.get("body").read())
     return "\n".join([d["text"] for d in response_body.get("content")])
 
+def call_llama_3(
+    bedrock_client,
+    model_name: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: Optional[float] = None,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    **kwargs,
+) -> str:
+    assert any_are_none(top_k, top_p), "At least one of top_k, top_p must be None"
+    bedrock_params = {
+        "prompt": prompt,
+        "max_gen_len": max_tokens_to_sample,
+    }
+    if top_p is not None and temperature is not None:
+        raise ValueError("Cannot specify both top_p and temperature; at most one must be specified.")
+
+    if top_k is not None:
+        assert isinstance(top_k, int)
+        bedrock_params["top_k"] = top_k
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+
+    response = bedrock_client.invoke_model(
+        body=json.dumps(bedrock_params),
+        modelId=model_name,
+        accept="application/json",
+        contentType="application/json",
+    )
+    response_body: Dict = json.loads(response.get("body").read())
+    return response_body.get("generation")
+
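The helper any_are_none is referenced here but not defined in this diff; it presumably lives elsewhere in fmcore. A minimal sketch consistent with the assertion message above (hypothetical, for illustration only):

from typing import Any

def any_are_none(*args: Any) -> bool:
    # True if at least one argument is None. call_llama_3 and call_mistral
    # assert this over (top_k, top_p) so that at most one of them is set.
    return any(x is None for x in args)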
+def call_mistral(
+    bedrock_client,
+    model_name: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: Optional[float] = None,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    **kwargs,
+) -> str:
+    assert any_are_none(top_k, top_p), "At least one of top_k, top_p must be None"
+    bedrock_params = {
+        "prompt": prompt,
+        "max_tokens": max_tokens_to_sample,
+    }
+    if top_p is not None and temperature is not None:
+        raise ValueError("Cannot specify both top_p and temperature; at most one must be specified.")
+
+    if top_k is not None:
+        assert isinstance(top_k, int)
+        bedrock_params["top_k"] = top_k
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+
+    response = bedrock_client.invoke_model(
+        body=json.dumps(bedrock_params),
+        modelId=model_name,
+        accept="application/json",
+        contentType="application/json",
+    )
+    response_body: Dict = json.loads(response.get("body").read())
+    return "\n".join([d["text"] for d in response_body["outputs"]])
+
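For context, the request body call_mistral builds and the "outputs" shape it parses match Bedrock's Mistral text-completion API. A minimal round-trip sketch, assuming a boto3 bedrock-runtime client with credentials configured and the public mistral.mistral-7b-instruct-v0:2 model ID:

import json
import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
response = bedrock.invoke_model(
    body=json.dumps({
        "prompt": "<s>[INST] Say hello. [/INST]",  # Mistral instruct prompt format
        "max_tokens": 64,
        "temperature": 0.7,
    }),
    modelId="mistral.mistral-7b-instruct-v0:2",
    accept="application/json",
    contentType="application/json",
)
payload = json.loads(response["body"].read())
# The response is {"outputs": [{"text": ..., "stop_reason": ...}]}, which is
# why call_mistral joins d["text"] over response_body["outputs"].
print(payload["outputs"][0]["text"])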
 def call_bedrock(
     *,
     bedrock_client: Any,
@@ -368,6 +442,22 @@ def call_bedrock(
             **generation_params,
         )
         return generated_text
+    elif "meta.llama3" in model_name:
+        generated_text: str = call_llama_3(
+            bedrock_client=bedrock_client,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params,
+        )
+        return generated_text
+    elif "mistral" in model_name:
+        generated_text: str = call_mistral(
+            bedrock_client=bedrock_client,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params,
+        )
+        return generated_text
     else:
         bedrock_invoke_model_params = {"prompt": prompt, **generation_params}
         response = bedrock_client.invoke_model(
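A hypothetical call through the new dispatch; the keyword arguments below are inferred from the branch bodies above, since this hunk does not show call_bedrock's full signature:

generated: str = call_bedrock(
    bedrock_client=bedrock,  # a boto3 "bedrock-runtime" client, as sketched earlier
    model_name="meta.llama3-8b-instruct-v1:0",  # matches the "meta.llama3" branch
    prompt="Explain quantization in one sentence.",
    generation_params={"max_tokens_to_sample": 128, "temperature": 0.5},
)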
