@@ -132,7 +132,7 @@ def call_claude_v3(
132132 * ,
133133 model_name : str ,
134134 prompt : str ,
135- image : Optional [Any ] = None ,
135+ images : Optional [List [ Any ] ] = None ,
136136 image_media_type : Optional [str ] = None ,
137137 max_tokens_to_sample : int ,
138138 thinking_tokens : Optional [int ] = None ,
@@ -144,14 +144,14 @@ def call_claude_v3(
144144 ** kwargs ,
145145 ) -> Union [str , Dict [str , str ]]:
146146 """
147- Call Claude v3 models with support for images and thinking parameter.
147+ Call Claude v3 models with support for multiple images and thinking parameter.
148148
149149 Args:
150150 bedrock_client: Boto3 bedrock client
151151 model_name (str): Claude model name
152152 prompt (str): Text prompt to send
153- image (Optional[Any]): Base64 -encoded image data
154- image_media_type (Optional[str]): Media type of the image (e.g., "image/png")
153+ images (Optional[List[ Any]] ): List of base64 -encoded image data
154+ image_media_type (Optional[str]): Media type of the images (e.g., "image/png")
155155 max_tokens_to_sample (int): Maximum tokens to generate
156156 thinking_tokens (Optional[int]): Number of tokens allocated for model thinking (Claude 3.7 only)
157157 temperature (Optional[float]): Temperature parameter for generation
@@ -168,10 +168,10 @@ def call_claude_v3(
168168 >>> bedrock_client = boto3.client(service_name="bedrock-runtime")
169169 >>> result = call_claude_v3(
170170 >>> bedrock_client=bedrock_client,
171- >>> prompt="Describe this image ",
171+ >>> prompt="Describe these images ",
172172 >>> model_name="anthropic.claude-3-sonnet-20240229-v1:0",
173173 >>> max_tokens_to_sample=500,
174- >>> image=base64_encoded_image ,
174+ >>> images=[base64_encoded_image1, base64_encoded_image2] ,
175175 >>> image_media_type="image/png"
176176 >>> )
177177 """
@@ -180,18 +180,19 @@ def call_claude_v3(
180180 ## Prepare the message content:
181181 message_content : List [Dict [str , Any ]] = []
182182
183- ## Add image if provided:
184- if image is not None :
185- message_content .append (
186- {
187- "type" : "image" ,
188- "source" : {
189- "type" : "base64" ,
190- "media_type" : image_media_type ,
191- "data" : image ,
192- },
193- }
194- )
183+ ## Add images if provided:
184+ if images is not None :
185+ for image in as_list (images ):
186+ message_content .append (
187+ {
188+ "type" : "image" ,
189+ "source" : {
190+ "type" : "base64" ,
191+ "media_type" : image_media_type ,
192+ "data" : image ,
193+ },
194+ }
195+ )
195196
196197 ## Add text prompt:
197198 message_content .append ({"type" : "text" , "text" : prompt })
@@ -751,14 +752,14 @@ def prompt_model_with_retries(
751752 self ,
752753 * ,
753754 prompt : str ,
754- image : Optional = None ,
755+ images : Optional [ List [ Any ]] = None ,
755756 ) -> Union [str , Dict [str , str ]]:
756757 """
757- Prompt the model with retries, supporting both text and image inputs.
758+ Prompt the model with retries, supporting both text and multiple image inputs.
758759
759760 Args:
760761 prompt (str): Text prompt
761- image: URL or data of an image to include
762+ images: List of URLs or data of images to include
762763
763764 Returns:
764765 Union[str, Dict[str, str]]: Generated text or dict with response and thinking
@@ -769,17 +770,29 @@ def prompt_model_with_retries(
769770 if self .bedrock_client is None :
770771 raise SystemError ("BedrockPrompter not initialized. Call initialize() first." )
771772
772- ## Process image if provided:
773- image : Optional = None
774- if isinstance (image , str ):
775- ## Check if the image is a URL:
776- if image .startswith ("http://" ) or image .startswith ("https://" ):
777- image = process_image_url (image )
773+ ## Process images if provided:
774+ processed_images = []
775+ if images is not None :
776+ for image in as_list (images ):
777+ if image is None :
778+ continue
779+ if isinstance (image , str ):
780+ ## Check if the image is a URL:
781+ if image .startswith ("http://" ) or image .startswith ("https://" ):
782+ processed_image = process_image_url (image )
783+ if processed_image is not None :
784+ processed_images .append (processed_image )
785+ else :
786+ ## Assume it's already base64 encoded:
787+ processed_images .append (image )
788+ elif image is not None :
789+ ## Assume it's raw image data that needs to be sent:
790+ processed_images .append (image )
778791
779792 try :
780793 generation_params = self .bedrock_text_generation_params
781- if image is not None :
782- generation_params ["image " ] = image
794+ if len ( processed_images ) > 0 :
795+ generation_params ["images " ] = processed_images
783796 generation_params ["image_media_type" ] = "image/png"
784797
785798 with Timer (silent = True ) as gen_timer :
@@ -817,24 +830,25 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
817830 """
818831 generated_texts : List [Union [str , Dict [str , str ]]] = []
819832
820- ## Identify image column if available :
821- image_column : Optional [str ] = None
833+ ## Identify all image columns :
834+ image_columns : List [str ] = []
822835 for col_name , col_type in batch .data_schema .flatten ().items ():
823836 if col_type == MLType .IMAGE :
824- image_column = col_name
825- break
837+ image_columns .append (col_name )
826838
827839 for i , prompt in enumerate (batch .prompts ().tolist ()):
828- ## Get image URL if available:
829- image : Optional = None
830- if image_column is not None :
831- image = batch .data [image_column ].iloc [i ]
832-
833- ## Generate text with image if available:
840+ ## Get all images if available:
841+ images : List = []
842+ for image_column in image_columns :
843+ image_value = batch .data [image_column ].iloc [i ]
844+ if image_value is not None :
845+ images .append (image_value )
846+
847+ ## Generate text with images if available:
834848 result : Union [str , Dict [str , str ]] = dispatch (
835849 self .prompt_model_with_retries ,
836850 prompt = prompt ,
837- image = image ,
851+ images = images if len ( images ) > 0 else None ,
838852 executor = self .executor ,
839853 parallelize = Parallelize .sync
840854 if self .hyperparams .max_workers is None
0 commit comments