
Commit 4d808fb

Added Amazon Nova models on Bedrock via the Converse API.
1 parent: 443bb45

1 file changed: +46 −0

src/fmcore/algorithm/bedrock.py

@@ -350,6 +350,44 @@ def call_llama_3(
     response_body: Dict = json.loads(response.get("body").read())
     return response_body.get("generation")
 
+def call_nova(
+    bedrock_client,
+    model_name: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    stop_sequences: Optional[List[str]] = None,
+    **kwargs,
+):
+    messages = [
+        {
+            "role": "user",
+            "content": [{"text": prompt}],
+        }
+    ]
+
+    inference_config = {
+        "maxTokens": max_tokens_to_sample,
+    }
+    if temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        inference_config["temperature"] = temperature
+    if top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        inference_config["topP"] = top_p
+    if stop_sequences is not None:
+        assert isinstance(stop_sequences, list)
+        if len(stop_sequences) > 0:
+            inference_config["stopSequences"] = stop_sequences
+
+    response_body: Dict = bedrock_client.converse(
+        modelId=model_name,
+        messages=messages,
+        inferenceConfig=inference_config,
+    )
+    return "\n".join([d["text"] for d in response_body["output"]["message"]["content"]])
+
 def call_mistral(
     bedrock_client,
     model_name: str,
@@ -458,6 +496,14 @@ def call_bedrock(
             **generation_params,
         )
         return generated_text
+    elif "nova" in model_name:
+        generated_text: str = call_nova(
+            bedrock_client=bedrock_client,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params,
+        )
+        return generated_text
     else:
         bedrock_invoke_model_params = {"prompt": prompt, **generation_params}
         response = bedrock_client.invoke_model(
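
For context, here is a minimal usage sketch of the new Nova path, not part of the commit. It assumes boto3 credentials are configured, that the import path matches the file shown above, and that the account has Converse access to a Nova model; the model ID and region are illustrative stand-ins.

# Minimal sketch (not part of the commit): exercising call_nova directly.
# Assumptions: boto3 credentials are configured, the import path below matches
# the repository layout, and "amazon.nova-lite-v1:0" / "us-east-1" stand in for
# whatever Nova model and region the account actually has access to.
import boto3

from fmcore.algorithm.bedrock import call_nova

bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")

generated_text = call_nova(
    bedrock_client=bedrock_client,
    model_name="amazon.nova-lite-v1:0",
    prompt="Explain the Bedrock Converse API in one sentence.",
    max_tokens_to_sample=256,
    temperature=0.5,
    top_p=0.9,
)
print(generated_text)

Note the design difference from the existing branches: invoke_model returns a streaming body that must be parsed with json.loads(response.get("body").read()), while converse returns an already-parsed dict, so call_nova reads the text blocks straight out of output.message.content and joins them.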
