Skip to content

Commit e738aae

Browse files
Added loading images from s3 on bedrock.
1 parent 0783676 commit e738aae

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

src/fmcore/algorithm/bedrock.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class ConfigSelectionStrategy(AutoEnum):
4949
import base64
5050
from io import BytesIO
5151

52+
import boto3
5253
import imageio
5354
from botocore.exceptions import ClientError
5455

@@ -87,6 +88,57 @@ def process_image_url(image_url: str) -> Optional[str]:
8788
Log.error(f"Failed to process image from URL {image_url}: {e}")
8889
return None
8990

91+
def process_image_s3(s3_path: str, aws_session: Optional[Any] = None) -> Optional[str]:
92+
"""
93+
Process an image from S3 by downloading it and converting it to base64.
94+
95+
Args:
96+
s3_path (str): S3 path to the image (e.g., "s3://bucket-name/path/to/image.jpg")
97+
aws_session (Optional[Any]): boto3 session to use. If None, uses default session.
98+
99+
Returns:
100+
Optional[str]: Base64-encoded image or None if processing failed
101+
102+
Example usage:
103+
>>> base64_image = process_image_s3("s3://my-bucket/images/photo.jpg")
104+
>>> if base64_image is not None:
105+
>>> print("Successfully processed image from S3")
106+
"""
107+
try:
108+
# Parse S3 path to extract bucket and key
109+
if not s3_path.startswith("s3://"):
110+
raise ValueError(f"Invalid S3 path format: {s3_path}. Expected format: s3://bucket-name/key")
111+
112+
s3_path_parts = s3_path[5:].split("/", 1) # Remove "s3://" and split into bucket and key
113+
if len(s3_path_parts) != 2:
114+
raise ValueError(f"Invalid S3 path format: {s3_path}. Expected format: s3://bucket-name/key")
115+
116+
bucket_name, object_key = s3_path_parts
117+
118+
# Create S3 client using provided session or default
119+
if aws_session is not None:
120+
s3_client = aws_session.client("s3")
121+
else:
122+
s3_client = boto3.client("s3")
123+
124+
# Download the image from S3
125+
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
126+
image_bytes = response["Body"].read()
127+
128+
# Convert the image to a standard format (PNG)
129+
image_array = imageio.imread(BytesIO(image_bytes))
130+
memfile = BytesIO()
131+
imageio.imwrite(memfile, image_array, format="png")
132+
memfile.seek(0)
133+
png_bytes = memfile.read()
134+
135+
# Encode as base64
136+
base64_image = base64.b64encode(png_bytes).decode("utf-8")
137+
return base64_image
138+
except Exception as e:
139+
Log.error(f"Failed to process image from S3 {s3_path}: {e}")
140+
return None
141+
90142
def call_claude_v1_v2(
91143
bedrock_client,
92144
model_name: str,
@@ -785,6 +837,10 @@ def prompt_model_with_retries(
785837
processed_image = process_image_url(image)
786838
if processed_image is not None:
787839
processed_images.append(processed_image)
840+
elif image.startswith("s3://"):
841+
processed_image = process_image_s3(image)
842+
if processed_image is not None:
843+
processed_images.append(processed_image)
788844
else:
789845
## Assume it's already base64 encoded:
790846
processed_images.append(image)

0 commit comments

Comments
 (0)