Skip to content

Commit d9f7f92

Browse files
Supported passing multiple images
1 parent bdb2c4e commit d9f7f92

File tree

1 file changed

+54
-40
lines changed

1 file changed

+54
-40
lines changed

src/fmcore/algorithm/bedrock.py

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def call_claude_v3(
132132
*,
133133
model_name: str,
134134
prompt: str,
135-
image: Optional[Any] = None,
135+
images: Optional[List[Any]] = None,
136136
image_media_type: Optional[str] = None,
137137
max_tokens_to_sample: int,
138138
thinking_tokens: Optional[int] = None,
@@ -144,14 +144,14 @@ def call_claude_v3(
144144
**kwargs,
145145
) -> Union[str, Dict[str, str]]:
146146
"""
147-
Call Claude v3 models with support for images and thinking parameter.
147+
Call Claude v3 models with support for multiple images and thinking parameter.
148148
149149
Args:
150150
bedrock_client: Boto3 bedrock client
151151
model_name (str): Claude model name
152152
prompt (str): Text prompt to send
153-
image (Optional[Any]): Base64-encoded image data
154-
image_media_type (Optional[str]): Media type of the image (e.g., "image/png")
153+
images (Optional[List[Any]]): List of base64-encoded image data
154+
image_media_type (Optional[str]): Media type of the images (e.g., "image/png")
155155
max_tokens_to_sample (int): Maximum tokens to generate
156156
thinking_tokens (Optional[int]): Number of tokens allocated for model thinking (Claude 3.7 only)
157157
temperature (Optional[float]): Temperature parameter for generation
@@ -168,10 +168,10 @@ def call_claude_v3(
168168
>>> bedrock_client = boto3.client(service_name="bedrock-runtime")
169169
>>> result = call_claude_v3(
170170
>>> bedrock_client=bedrock_client,
171-
>>> prompt="Describe this image",
171+
>>> prompt="Describe these images",
172172
>>> model_name="anthropic.claude-3-sonnet-20240229-v1:0",
173173
>>> max_tokens_to_sample=500,
174-
>>> image=base64_encoded_image,
174+
>>> images=[base64_encoded_image1, base64_encoded_image2],
175175
>>> image_media_type="image/png"
176176
>>> )
177177
"""
@@ -180,18 +180,19 @@ def call_claude_v3(
180180
## Prepare the message content:
181181
message_content: List[Dict[str, Any]] = []
182182

183-
## Add image if provided:
184-
if image is not None:
185-
message_content.append(
186-
{
187-
"type": "image",
188-
"source": {
189-
"type": "base64",
190-
"media_type": image_media_type,
191-
"data": image,
192-
},
193-
}
194-
)
183+
## Add images if provided:
184+
if images is not None:
185+
for image in as_list(images):
186+
message_content.append(
187+
{
188+
"type": "image",
189+
"source": {
190+
"type": "base64",
191+
"media_type": image_media_type,
192+
"data": image,
193+
},
194+
}
195+
)
195196

196197
## Add text prompt:
197198
message_content.append({"type": "text", "text": prompt})
@@ -751,14 +752,14 @@ def prompt_model_with_retries(
751752
self,
752753
*,
753754
prompt: str,
754-
image: Optional = None,
755+
images: Optional[List[Any]] = None,
755756
) -> Union[str, Dict[str, str]]:
756757
"""
757-
Prompt the model with retries, supporting both text and image inputs.
758+
Prompt the model with retries, supporting both text and multiple image inputs.
758759
759760
Args:
760761
prompt (str): Text prompt
761-
image: URL or data of an image to include
762+
images: List of URLs or data of images to include
762763
763764
Returns:
764765
Union[str, Dict[str, str]]: Generated text or dict with response and thinking
@@ -769,17 +770,29 @@ def prompt_model_with_retries(
769770
if self.bedrock_client is None:
770771
raise SystemError("BedrockPrompter not initialized. Call initialize() first.")
771772

772-
## Process image if provided:
773-
image: Optional = None
774-
if isinstance(image, str):
775-
## Check if the image is a URL:
776-
if image.startswith("http://") or image.startswith("https://"):
777-
image = process_image_url(image)
773+
## Process images if provided:
774+
processed_images = []
775+
if images is not None:
776+
for image in as_list(images):
777+
if image is None:
778+
continue
779+
if isinstance(image, str):
780+
## Check if the image is a URL:
781+
if image.startswith("http://") or image.startswith("https://"):
782+
processed_image = process_image_url(image)
783+
if processed_image is not None:
784+
processed_images.append(processed_image)
785+
else:
786+
## Assume it's already base64 encoded:
787+
processed_images.append(image)
788+
elif image is not None:
789+
## Assume it's raw image data that needs to be sent:
790+
processed_images.append(image)
778791

779792
try:
780793
generation_params = self.bedrock_text_generation_params
781-
if image is not None:
782-
generation_params["image"] = image
794+
if len(processed_images) > 0:
795+
generation_params["images"] = processed_images
783796
generation_params["image_media_type"] = "image/png"
784797

785798
with Timer(silent=True) as gen_timer:
@@ -817,24 +830,25 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
817830
"""
818831
generated_texts: List[Union[str, Dict[str, str]]] = []
819832

820-
## Identify image column if available:
821-
image_column: Optional[str] = None
833+
## Identify all image columns:
834+
image_columns: List[str] = []
822835
for col_name, col_type in batch.data_schema.flatten().items():
823836
if col_type == MLType.IMAGE:
824-
image_column = col_name
825-
break
837+
image_columns.append(col_name)
826838

827839
for i, prompt in enumerate(batch.prompts().tolist()):
828-
## Get image URL if available:
829-
image: Optional = None
830-
if image_column is not None:
831-
image = batch.data[image_column].iloc[i]
832-
833-
## Generate text with image if available:
840+
## Get all images if available:
841+
images: List = []
842+
for image_column in image_columns:
843+
image_value = batch.data[image_column].iloc[i]
844+
if image_value is not None:
845+
images.append(image_value)
846+
847+
## Generate text with images if available:
834848
result: Union[str, Dict[str, str]] = dispatch(
835849
self.prompt_model_with_retries,
836850
prompt=prompt,
837-
image=image,
851+
images=images if len(images) > 0 else None,
838852
executor=self.executor,
839853
parallelize=Parallelize.sync
840854
if self.hyperparams.max_workers is None

0 commit comments

Comments (0)