Commit 54a2fa9

Supported passing multiple images
1 parent bdb2c4e commit 54a2fa9

File tree: 1 file changed (+56, -40 lines)

src/fmcore/algorithm/bedrock.py (56 additions, 40 deletions)
@@ -132,7 +132,7 @@ def call_claude_v3(
     *,
     model_name: str,
     prompt: str,
-    image: Optional[Any] = None,
+    images: Optional[List[Any]] = None,
     image_media_type: Optional[str] = None,
     max_tokens_to_sample: int,
     thinking_tokens: Optional[int] = None,
@@ -144,14 +144,14 @@ def call_claude_v3(
     **kwargs,
 ) -> Union[str, Dict[str, str]]:
     """
-    Call Claude v3 models with support for images and thinking parameter.
+    Call Claude v3 models with support for multiple images and thinking parameter.
 
     Args:
         bedrock_client: Boto3 bedrock client
         model_name (str): Claude model name
         prompt (str): Text prompt to send
-        image (Optional[Any]): Base64-encoded image data
-        image_media_type (Optional[str]): Media type of the image (e.g., "image/png")
+        images (Optional[List[Any]]): List of base64-encoded image data
+        image_media_type (Optional[str]): Media type of the images (e.g., "image/png")
         max_tokens_to_sample (int): Maximum tokens to generate
         thinking_tokens (Optional[int]): Number of tokens allocated for model thinking (Claude 3.7 only)
         temperature (Optional[float]): Temperature parameter for generation
@@ -168,10 +168,10 @@ def call_claude_v3(
         >>> bedrock_client = boto3.client(service_name="bedrock-runtime")
         >>> result = call_claude_v3(
         >>>     bedrock_client=bedrock_client,
-        >>>     prompt="Describe this image",
+        >>>     prompt="Describe these images",
         >>>     model_name="anthropic.claude-3-sonnet-20240229-v1:0",
         >>>     max_tokens_to_sample=500,
-        >>>     image=base64_encoded_image,
+        >>>     images=[base64_encoded_image1, base64_encoded_image2],
         >>>     image_media_type="image/png"
         >>> )
     """
@@ -180,18 +180,20 @@ def call_claude_v3(
     ## Prepare the message content:
     message_content: List[Dict[str, Any]] = []
 
-    ## Add image if provided:
-    if image is not None:
-        message_content.append(
-            {
-                "type": "image",
-                "source": {
-                    "type": "base64",
-                    "media_type": image_media_type,
-                    "data": image,
-                },
-            }
-        )
+    ## Add images if provided:
+    if images is not None:
+        for image in as_list(images):
+            if image is not None:
+                message_content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": image_media_type,
+                            "data": image,
+                        },
+                    }
+                )
 
     ## Add text prompt:
     message_content.append({"type": "text", "text": prompt})
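
Note: the loop above yields one image block per non-None entry, followed by a single text block. Below is a small standalone sketch of that content-building logic; build_message_content and the local as_list are stand-ins for illustration only, since the real code lives inline in call_claude_v3 and uses the repo's as_list helper.

    from typing import Any, Dict, List, Optional

    def as_list(x: Any) -> List[Any]:
        ## Stand-in for the repo helper: wrap a single value in a list.
        return x if isinstance(x, list) else [x]

    def build_message_content(
        prompt: str,
        images: Optional[List[Any]] = None,
        image_media_type: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        ## Mirrors the diff: one base64 image block per entry, then the text prompt.
        message_content: List[Dict[str, Any]] = []
        if images is not None:
            for image in as_list(images):
                if image is not None:
                    message_content.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": image_media_type,
                                "data": image,
                            },
                        }
                    )
        message_content.append({"type": "text", "text": prompt})
        return message_content

    content = build_message_content(
        prompt="Describe these images",
        images=["<base64 png #1>", "<base64 png #2>"],
        image_media_type="image/png",
    )
    assert [block["type"] for block in content] == ["image", "image", "text"]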
@@ -751,14 +753,14 @@ def prompt_model_with_retries(
         self,
         *,
         prompt: str,
-        image: Optional = None,
+        images: Optional[List[Any]] = None,
     ) -> Union[str, Dict[str, str]]:
         """
-        Prompt the model with retries, supporting both text and image inputs.
+        Prompt the model with retries, supporting both text and multiple image inputs.
 
         Args:
             prompt (str): Text prompt
-            image: URL or data of an image to include
+            images: List of URLs or data of images to include
 
         Returns:
             Union[str, Dict[str, str]]: Generated text or dict with response and thinking
@@ -769,17 +771,29 @@ def prompt_model_with_retries(
         if self.bedrock_client is None:
             raise SystemError("BedrockPrompter not initialized. Call initialize() first.")
 
-        ## Process image if provided:
-        image: Optional = None
-        if isinstance(image, str):
-            ## Check if the image is a URL:
-            if image.startswith("http://") or image.startswith("https://"):
-                image = process_image_url(image)
+        ## Process images if provided:
+        processed_images = []
+        if images is not None:
+            for image in as_list(images):
+                if image is None:
+                    continue
+                if isinstance(image, str):
+                    ## Check if the image is a URL:
+                    if image.startswith("http://") or image.startswith("https://"):
+                        processed_image = process_image_url(image)
+                        if processed_image is not None:
+                            processed_images.append(processed_image)
+                    else:
+                        ## Assume it's already base64 encoded:
+                        processed_images.append(image)
+                elif image is not None:
+                    ## Assume it's raw image data that needs to be sent:
+                    processed_images.append(image)
 
         try:
             generation_params = self.bedrock_text_generation_params
-            if image is not None:
-                generation_params["image"] = image
+            if len(processed_images) > 0:
+                generation_params["images"] = processed_images
                 generation_params["image_media_type"] = "image/png"
 
             with Timer(silent=True) as gen_timer:
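
Note: process_image_url is not part of this diff; the loop above only relies on it returning base64-encoded data or None. A plausible sketch of such a helper, assuming it fetches the URL with requests and base64-encodes the response body:

    import base64
    from typing import Optional

    import requests

    def process_image_url(image_url: str, timeout: float = 10.0) -> Optional[str]:
        ## Assumed behavior: download the image and return it base64-encoded,
        ## or None on any request failure (the caller skips None results).
        try:
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException:
            return None
        return base64.b64encode(response.content).decode("utf-8")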
@@ -817,24 +831,25 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
         """
         generated_texts: List[Union[str, Dict[str, str]]] = []
 
-        ## Identify image column if available:
-        image_column: Optional[str] = None
+        ## Identify all image columns:
+        image_columns: List[str] = []
         for col_name, col_type in batch.data_schema.flatten().items():
             if col_type == MLType.IMAGE:
-                image_column = col_name
-                break
+                image_columns.append(col_name)
 
         for i, prompt in enumerate(batch.prompts().tolist()):
-            ## Get image URL if available:
-            image: Optional = None
-            if image_column is not None:
-                image = batch.data[image_column].iloc[i]
-
-            ## Generate text with image if available:
+            ## Get all images if available:
+            images: List = []
+            for image_column in image_columns:
+                image_value = batch.data[image_column].iloc[i]
+                if image_value is not None:
+                    images.append(image_value)
+
+            ## Generate text with images if available:
             result: Union[str, Dict[str, str]] = dispatch(
                 self.prompt_model_with_retries,
                 prompt=prompt,
-                image=image,
+                images=images if len(images) > 0 else None,
                 executor=self.executor,
                 parallelize=Parallelize.sync
                 if self.hyperparams.max_workers is None
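
Note: with the change above, every MLType.IMAGE column contributes a value per row instead of only the first such column. A small pandas sketch of the per-row collection; the column names and URLs are illustrative, not taken from the repo.

    import pandas as pd

    ## Illustrative frame with two image-typed columns (in fmcore, these would be
    ## the columns whose schema type is MLType.IMAGE).
    data = pd.DataFrame(
        {
            "front_image": ["https://example.com/a.png", None],
            "back_image": ["https://example.com/b.png", "https://example.com/c.png"],
        }
    )
    image_columns = ["front_image", "back_image"]

    for i in range(len(data)):
        images = []
        for image_column in image_columns:
            image_value = data[image_column].iloc[i]
            if image_value is not None:
                images.append(image_value)
        print(i, images)  ## row 0 -> two images, row 1 -> one image (None is skipped)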
@@ -869,6 +884,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
                 ## Handle case where result is a string:
                 thinking_outputs.append("")
                 generated_texts.append(result)
+                generation_times.append(math.nan)
             else:
                 raise ValueError(f"Unexpected result type: {type(result)} with value:\n{result}")
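
Note: the added generation_times.append(math.nan) keeps the three per-prompt lists the same length when a result arrives as a plain string rather than a dict. A toy illustration of that bookkeeping; the dict keys here are assumed for the sketch, not taken from the repo.

    import math

    generated_texts, thinking_outputs, generation_times = [], [], []

    for result in [{"response": "a cat", "thinking": "...", "time": 1.2}, "a dog"]:
        if isinstance(result, dict):
            generated_texts.append(result["response"])
            thinking_outputs.append(result["thinking"])
            generation_times.append(result["time"])
        elif isinstance(result, str):
            ## Handle case where result is a string:
            thinking_outputs.append("")
            generated_texts.append(result)
            generation_times.append(math.nan)  ## placeholder keeps the lists aligned

    assert len(generated_texts) == len(thinking_outputs) == len(generation_times)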
