
Commit 80e0713

Added Nova models on Bedrock via the Converse API.
1 parent 443bb45 commit 80e0713


src/fmcore/algorithm/bedrock.py

Lines changed: 45 additions & 0 deletions
@@ -350,6 +350,43 @@ def call_llama_3(
     response_body: Dict = json.loads(response.get("body").read())
     return response_body.get("generation")
 
+def call_nova(
+    bedrock_client,
+    model_name: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    stop_sequences: Optional[List[str]] = None,
+    **kwargs,
+):
+    messages = [
+        {
+            "role": "user",
+            "content": [{"text": prompt}],
+        }
+    ]
+
+    inference_config = {
+        "maxTokens": max_tokens_to_sample,
+    }
+    if temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        inference_config["temperature"] = temperature
+    if top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        inference_config["topP"] = top_p
+    if stop_sequences is not None:
+        assert isinstance(stop_sequences, list) and len(stop_sequences) > 0
+        inference_config["stopSequences"] = stop_sequences
+
+    response_body: Dict = bedrock_client.converse(
+        modelId=model_name,
+        messages=messages,
+        inferenceConfig=inference_config,
+    )
+    return "\n".join([d["text"] for d in response_body["output"]["message"]["content"]])
+
 def call_mistral(
     bedrock_client,
     model_name: str,
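
For context on the return expression above: the Bedrock Converse API returns the assistant message as a list of content blocks under output.message.content, which call_nova joins with newlines. A minimal sketch of the response shape the code relies on; the literal values here are invented for illustration:

    response_body = {
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {"text": "First text block."},
                    {"text": "Second text block."},
                ],
            }
        },
        "stopReason": "end_turn",  # the real payload also carries "usage" and "metrics", omitted here
    }
    # call_nova's return line would yield: "First text block.\nSecond text block."
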
@@ -458,6 +495,14 @@ def call_bedrock(
             **generation_params,
         )
         return generated_text
+    elif "nova" in model_name:
+        generated_text: str = call_nova(
+            bedrock_client=bedrock_client,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params,
+        )
+        return generated_text
     else:
         bedrock_invoke_model_params = {"prompt": prompt, **generation_params}
         response = bedrock_client.invoke_model(
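
The new branch dispatches any model name containing "nova" to call_nova. A minimal usage sketch, not part of the commit; the region and the Nova model ID below are assumptions chosen for illustration:

    import boto3

    # Assumed region and model ID; substitute whatever your account has access to.
    bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")

    generated: str = call_nova(
        bedrock_client=bedrock_client,
        model_name="amazon.nova-lite-v1:0",  # hypothetical choice of Nova model
        prompt="Explain the Bedrock Converse API in one sentence.",
        max_tokens_to_sample=128,
        temperature=0.5,
    )
    print(generated)
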
