67 changes: 67 additions & 0 deletions backend/deepcell_label/blueprints.py
@@ -11,12 +11,19 @@
import requests
from flask import Blueprint, abort, current_app, jsonify, request, send_file
from werkzeug.exceptions import HTTPException
from deepcell_label.types import BBox
from deepcell_label.utils import generate_onnx_sam_masks, process_image_for_sam, retrieve_sam_model_data

from deepcell_label.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, DELETE_TEMP
from deepcell_label.export import Export
from deepcell_label.label import Edit
from deepcell_label.loaders import Loader
from deepcell_label.models import Project
import json


import cv2
import numpy as np

bp = Blueprint('label', __name__) # pylint: disable=C0103

@@ -46,6 +53,7 @@ def handle_exception(error):
@bp.route('/api/project/<project>', methods=['GET'])
def get_project(project):
start = timeit.default_timer()
print(project)
project = Project.get(project)
if not project:
return abort(404, description=f'project {project} not found')
@@ -206,3 +214,62 @@ def submit_project():
timeit.default_timer() - start,
)
return {}

@bp.route('/api/testSamGeneration', methods=['POST'])
def test_sam_generation():
"""Tests the generation of a SAM model. This is a temporary route to test the generation of a SAM model.

Right now, it takes a single tiff image (example_input_combined.tif in the static folder) and generates
an image embedding (saved as embedding.npy) and an ONNX model (saved as model.onnx) in the static folder.

With these two artifacts, we are then able to call the next endpoint below (/api/testSamPrediction) to
generate a mask for a given bounding box.

Generating these two artifacts on a CPU with 8 cores takes about 50 - 60 seconds. On a 10 core GPU, it
takes less than 5 seconds.

The idea here is to offload this test API endpoint to a worker job that will generate the artifacts and
save them to S3 when a project is created in Deepcell Label. Then, when a user wants to generate a mask
for a given bounding box, we can just load the artifacts from S3 and generate the mask.

This is a temporary workaround to avoid having to use WebAssembly on the frontend to generate the mask
(streaming the .wasm files were giving us troubles).
"""
json_data = request.get_json()

process_image_for_sam(json_data["image"], json_data["embedding_output"], json_data["onnx_output"])

return {"message": "success"}

@bp.route('/api/testSamPrediction', methods=['POST'])
def test_sam_prediction():
"""Generates a mask based upon a provided bbox dataset.

Bbox data should be in the following format:

{
"x_start": int,
"x_end": int,
"y_start": int,
"y_end": int,
}

Temporarily, we are just using a static image (example_input_combined.tif). With this image, we are able
to load the embeddings npy file and the ONNX model file to generate a mask for the provided bbox data.

This DOES NOT require us to re-load the model and call the resource intensive `set_image()` method.

The output of this endpoint is an ndarray of 0's and 1's indicating where to draw the mask on the frontend.
"""
json_data = json.loads(request.data, strict=False)
bbox = BBox(**json_data)

image_embedding, ort_session = retrieve_sam_model_data()

image = cv2.imread('./deepcell_label/static/example_input_combined.tif', cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
image = image.astype(np.uint8)

masks = generate_onnx_sam_masks(image_embedding, ort_session, image, bbox)

return {"data": masks[0].astype(int).tolist()}
1 change: 1 addition & 0 deletions backend/deepcell_label/label.py
@@ -64,6 +64,7 @@ def load(self, labels_zip):
"""
if not zipfile.is_zipfile(labels_zip):
raise ValueError('Attached labels.zip is not a zip file.')

zf = zipfile.ZipFile(labels_zip)

# Load edit args
Binary file added backend/deepcell_label/static/embedding.npy
Binary file not shown.
Binary file added backend/deepcell_label/static/model.onnx
Binary file not shown.
7 changes: 7 additions & 0 deletions backend/deepcell_label/types.py
@@ -0,0 +1,7 @@
from pydantic import BaseModel

class BBox(BaseModel):
x_start: int
y_start: int
x_end: int
y_end: int
228 changes: 228 additions & 0 deletions backend/deepcell_label/utils.py
@@ -1,6 +1,23 @@
"""Utility functions for DeepCell Label"""

import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.modeling.sam import Sam
from segment_anything.utils.onnx import SamOnnxModel
from pathlib import Path
import os
import torch

from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from typing import Tuple
import onnxruntime

import cv2

from copy import deepcopy

from deepcell_label.types import BBox


def convert_lineage(lineage):
@@ -134,3 +151,214 @@ def permute_axes(array, input_axes, output_axes):
"""
permutation = tuple(input_axes.find(dim) for dim in output_axes)
return array.transpose(permutation)

def _get_sam_checkpoint(model_type: str) -> str:
"""Gets the path to a SAM model checkpoint.

Arguments:
model_type (str): Type of SAM model to load.

Returns:
str: Path to SAM model checkpoint (local).
"""
weights_urls = {
"default": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}

    cache_dir = Path.home() / ".cache/sam"
    cache_dir.mkdir(parents=True, exist_ok=True)
    weight_path = cache_dir / weights_urls[model_type].split("/")[-1]

    # Download the checkpoint into the local cache on first use.
    if not weight_path.exists():
        torch.hub.download_url_to_file(weights_urls[model_type], str(weight_path))

    return str(weight_path)

def _get_sam_model(model_type: str) -> Sam:
"""Loads a SAM model.

Arguments:
model_type (str): Type of SAM model to load.

Returns:
Sam: SAM model.
"""
checkpoint_path = _get_sam_checkpoint(model_type)

return sam_model_registry[model_type](checkpoint=checkpoint_path)

def _get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)

def apply_sam_coords(coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = _get_preprocess_shape(
original_size[0], original_size[1], 1024
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
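
For illustration only (not part of this diff): a small worked example of the resizing convention these two helpers implement, assuming a 256-wide by 512-tall image and one (x, y) point at its center.

# The long side is scaled to 1024, and coordinates are scaled by the same factors.
print(_get_preprocess_shape(512, 256, 1024))                      # (1024, 512)
print(apply_sam_coords(np.array([[128.0, 256.0]]), (512, 256)))   # [[256. 512.]]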

def generate_image_embedding(sam_model: Sam, image: np.ndarray, output_path: str) -> None:
"""Generates an SamPredictor image embedding and saves it to a file.

Arguments:
sam_model (Sam): SAM model to generate embedding with.
image (np.ndarray): Image to generate embedding for.
output_path (str): Path to save embedding to.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)

image_embedding = predictor.get_image_embedding().cpu().numpy()

np.save(output_path, image_embedding)

def generate_onnx_model(sam_model: Sam, original_image_width: int, original_image_height: int, output_path: str) -> None:
"""Generates an ONNX model from a SAM model.

Arguments:
sam_model (Sam): SAM model to convert to ONNX.
original_image_width (int): Width of the original image.
original_image_height (int): Height of the original image.
output_path (str): Path to save ONNX model to.
"""
onnx_model = SamOnnxModel(sam_model, return_single_mask=True)

dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}

embed_dim = sam_model.prompt_encoder.embed_dim
embed_size = sam_model.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([original_image_width, original_image_height], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]

temp_path = "./deepcell_label/static/tmp.onnx"

with open(temp_path, "wb") as f:
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=17,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)

quantize_dynamic(
model_input=temp_path,
model_output=output_path,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)

os.remove(temp_path)

def process_image_for_sam(image_path: str, embedding_output_path: str, onnx_output_path: str) -> None:
"""Processes an image for SAM.

Arguments:
image_path (str): Path to image to process.
embedding_output_path (str): Path to save image embedding to.
onnx_output_path (str): Path to save ONNX model to.
"""
print(image_path)
print(os.getcwd())
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
image = image.astype(np.uint8)

sam_model = _get_sam_model("vit_h")

generate_image_embedding(sam_model, image, embedding_output_path)

generate_onnx_model(sam_model, image.shape[1], image.shape[0], onnx_output_path)

def retrieve_sam_model_data() -> Tuple[np.ndarray, onnxruntime.InferenceSession]:
"""Retrieves SAM model data; both the image embedding and the ONNX model.

Returns:
Tuple[np.ndarray, onnxruntime.InferenceSession]: Tuple containing the image embedding and the ONNX model.

TODO: This should be rewritten to interact with S3 instead of a local static folder.
"""
image_embedding = np.load("./deepcell_label/static/embedding.npy")
ort_session = onnxruntime.InferenceSession("./deepcell_label/static/model.onnx")

return image_embedding, ort_session

def generate_onnx_sam_masks(
image_embedding: np.ndarray,
ort_session: onnxruntime.InferenceSession,
image: np.ndarray,
bbox: BBox
) -> np.ndarray:
"""Generates masks from an image embedding and ONNX model.

Arguments:
image_embedding (np.ndarray): Image embedding to use.
ort_session (onnxruntime.InferenceSession): ONNX model to use.
image (np.ndarray): Image to generate masks for.
        bbox (BBox): Bounding box to generate masks for.

Returns:
np.ndarray: Generated masks. For now only one mask is returned.
"""
input_box = np.array([bbox.x_start, bbox.y_start, bbox.x_end, bbox.y_end])

    # The ONNX model expects point inputs alongside the box, so pad with a dummy
    # point labeled -1 (SAM's padding / "not a point" label) so it does not bias the mask.
    input_point = np.array([[0.0, 0.0]])
    input_label = np.array([-1])

onnx_box_coords = input_box.reshape(2, 2)
onnx_box_labels = np.array([2,3])

onnx_coord = np.concatenate([input_point, onnx_box_coords], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[None, :].astype(np.float32)

onnx_coord = apply_sam_coords(onnx_coord, image.shape[:2]).astype(np.float32)

onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

ort_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > 0.0

return masks
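
For illustration only (not part of this diff): a sketch of how these helpers fit together, mirroring the two test endpoints in blueprints.py. The file paths and bounding box are placeholders.

# 1) One-time, expensive step: build the image embedding and the quantized ONNX model.
process_image_for_sam(
    "./deepcell_label/static/example_input_combined.tif",
    "./deepcell_label/static/embedding.npy",
    "./deepcell_label/static/model.onnx",
)

# 2) Cheap, per-request step: load the cached artifacts and predict a mask for a box.
image_embedding, ort_session = retrieve_sam_model_data()
image = cv2.imread("./deepcell_label/static/example_input_combined.tif", cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR).astype(np.uint8)
masks = generate_onnx_sam_masks(
    image_embedding, ort_session, image, BBox(x_start=40, y_start=55, x_end=160, y_end=190)
)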
7 changes: 7 additions & 0 deletions backend/requirements.txt
@@ -16,3 +16,10 @@ scikit-image~=0.19.0
sqlalchemy~=1.3.24
tifffile
imagecodecs
torch==2.1.0
torchvision==0.16.0
segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
opencv-python==4.8.1.78
onnx==1.15.0
onnxruntime==1.16.1
pydantic==2.4.2
5 changes: 4 additions & 1 deletion frontend/package.json
@@ -12,7 +12,7 @@
"@babel/plugin-transform-react-jsx": "^7.16.7",
"@emotion/react": "^11.7.1",
"@emotion/styled": "^11.6.0",
"@hms-dbmi/viv": "^0.12.6",
"@hms-dbmi/viv": "^0.13.8",
"@luma.gl/core": "^8.5.10",
"@mui/icons-material": "^5.3.1",
"@mui/lab": "^5.0.0-alpha.107",
@@ -49,7 +49,9 @@
"lodash.debounce": "^4.0.8",
"mathjs": "^11.6.0",
"mousetrap": "^1.6.5",
"npyjs": "^0.5.0",
"nyc": "^15.1.0",
"onnxruntime-web": "^1.16.1",
"plotly.js": "2.12.1",
"prop-types": "^15.8.1",
"quickselect": "^2.0.0",
@@ -102,6 +104,7 @@
"fake-indexeddb": "^3.1.7",
"prettier": "^2.3.1",
"prettier-plugin-organize-imports": "^2.1.0",
"wasm-loader": "^1.3.0",
"worker-loader": "^3.0.8"
},
"husky": {
Binary file added frontend/public/sam_onnx_example.onnx
Binary file not shown.
Binary file added frontend/public/sam_onnx_quantized_example.onnx
Binary file not shown.