@@ -132,7 +132,7 @@ def call_claude_v3(
132132 * ,
133133 model_name : str ,
134134 prompt : str ,
135- image : Optional [Any ] = None ,
135+ images : Optional [List [ Any ] ] = None ,
136136 image_media_type : Optional [str ] = None ,
137137 max_tokens_to_sample : int ,
138138 thinking_tokens : Optional [int ] = None ,
@@ -144,14 +144,14 @@ def call_claude_v3(
144144 ** kwargs ,
145145 ) -> Union [str , Dict [str , str ]]:
146146 """
147- Call Claude v3 models with support for images and thinking parameter.
147+ Call Claude v3 models with support for multiple images and thinking parameter.
148148
149149 Args:
150150 bedrock_client: Boto3 bedrock client
151151 model_name (str): Claude model name
152152 prompt (str): Text prompt to send
153- image (Optional[Any]): Base64 -encoded image data
154- image_media_type (Optional[str]): Media type of the image (e.g., "image/png")
153+ images (Optional[List[ Any]] ): List of base64 -encoded image data
154+ image_media_type (Optional[str]): Media type of the images (e.g., "image/png")
155155 max_tokens_to_sample (int): Maximum tokens to generate
156156 thinking_tokens (Optional[int]): Number of tokens allocated for model thinking (Claude 3.7 only)
157157 temperature (Optional[float]): Temperature parameter for generation
@@ -168,10 +168,10 @@ def call_claude_v3(
168168 >>> bedrock_client = boto3.client(service_name="bedrock-runtime")
169169 >>> result = call_claude_v3(
170170 >>> bedrock_client=bedrock_client,
171- >>> prompt="Describe this image ",
171+ >>> prompt="Describe these images ",
172172 >>> model_name="anthropic.claude-3-sonnet-20240229-v1:0",
173173 >>> max_tokens_to_sample=500,
174- >>> image=base64_encoded_image ,
174+ >>> images=[base64_encoded_image1, base64_encoded_image2] ,
175175 >>> image_media_type="image/png"
176176 >>> )
177177 """
@@ -180,18 +180,20 @@ def call_claude_v3(
180180 ## Prepare the message content:
181181 message_content : List [Dict [str , Any ]] = []
182182
183- ## Add image if provided:
184- if image is not None :
185- message_content .append (
186- {
187- "type" : "image" ,
188- "source" : {
189- "type" : "base64" ,
190- "media_type" : image_media_type ,
191- "data" : image ,
192- },
193- }
194- )
183+ ## Add images if provided:
184+ if images is not None :
185+ for image in as_list (images ):
186+ if image is not None :
187+ message_content .append (
188+ {
189+ "type" : "image" ,
190+ "source" : {
191+ "type" : "base64" ,
192+ "media_type" : image_media_type ,
193+ "data" : image ,
194+ },
195+ }
196+ )
195197
196198 ## Add text prompt:
197199 message_content .append ({"type" : "text" , "text" : prompt })
@@ -751,14 +753,14 @@ def prompt_model_with_retries(
751753 self ,
752754 * ,
753755 prompt : str ,
754- image : Optional = None ,
756+ images : Optional [ List [ Any ]] = None ,
755757 ) -> Union [str , Dict [str , str ]]:
756758 """
757- Prompt the model with retries, supporting both text and image inputs.
759+ Prompt the model with retries, supporting both text and multiple image inputs.
758760
759761 Args:
760762 prompt (str): Text prompt
761- image: URL or data of an image to include
763+ images: List of URLs or data of images to include
762764
763765 Returns:
764766 Union[str, Dict[str, str]]: Generated text or dict with response and thinking
@@ -769,17 +771,29 @@ def prompt_model_with_retries(
769771 if self .bedrock_client is None :
770772 raise SystemError ("BedrockPrompter not initialized. Call initialize() first." )
771773
772- ## Process image if provided:
773- image : Optional = None
774- if isinstance (image , str ):
775- ## Check if the image is a URL:
776- if image .startswith ("http://" ) or image .startswith ("https://" ):
777- image = process_image_url (image )
774+ ## Process images if provided:
775+ processed_images = []
776+ if images is not None :
777+ for image in as_list (images ):
778+ if image is None :
779+ continue
780+ if isinstance (image , str ):
781+ ## Check if the image is a URL:
782+ if image .startswith ("http://" ) or image .startswith ("https://" ):
783+ processed_image = process_image_url (image )
784+ if processed_image is not None :
785+ processed_images .append (processed_image )
786+ else :
787+ ## Assume it's already base64 encoded:
788+ processed_images .append (image )
789+ elif image is not None :
790+ ## Assume it's raw image data that needs to be sent:
791+ processed_images .append (image )
778792
779793 try :
780794 generation_params = self .bedrock_text_generation_params
781- if image is not None :
782- generation_params ["image " ] = image
795+ if len ( processed_images ) > 0 :
796+ generation_params ["images " ] = processed_images
783797 generation_params ["image_media_type" ] = "image/png"
784798
785799 with Timer (silent = True ) as gen_timer :
@@ -817,24 +831,25 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
817831 """
818832 generated_texts : List [Union [str , Dict [str , str ]]] = []
819833
820- ## Identify image column if available :
821- image_column : Optional [str ] = None
834+ ## Identify all image columns :
835+ image_columns : List [str ] = []
822836 for col_name , col_type in batch .data_schema .flatten ().items ():
823837 if col_type == MLType .IMAGE :
824- image_column = col_name
825- break
838+ image_columns .append (col_name )
826839
827840 for i , prompt in enumerate (batch .prompts ().tolist ()):
828- ## Get image URL if available:
829- image : Optional = None
830- if image_column is not None :
831- image = batch .data [image_column ].iloc [i ]
832-
833- ## Generate text with image if available:
841+ ## Get all images if available:
842+ images : List = []
843+ for image_column in image_columns :
844+ image_value = batch .data [image_column ].iloc [i ]
845+ if image_value is not None :
846+ images .append (image_value )
847+
848+ ## Generate text with images if available:
834849 result : Union [str , Dict [str , str ]] = dispatch (
835850 self .prompt_model_with_retries ,
836851 prompt = prompt ,
837- image = image ,
852+ images = images if len ( images ) > 0 else None ,
838853 executor = self .executor ,
839854 parallelize = Parallelize .sync
840855 if self .hyperparams .max_workers is None
@@ -869,6 +884,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict[str, List[Any]]:
869884 ## Handle case where result is a string:
870885 thinking_outputs .append ("" )
871886 generated_texts .append (result )
887+ generation_times .append (math .nan )
872888 else :
873889 raise ValueError (f"Unexpected result type: { type (result )} with value:\n { result } " )
874890
0 commit comments