Skip to content
7 changes: 4 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ pipeline {
sh 'uv sync --locked --only-group gdal-build-dependencies'
sh 'uv sync --locked'
sh 'uv run pybabel compile -d sketch_map_tool/translations'
sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-OSM.pt'
sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-ESRI.pt'
sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-CLS.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt'
}
}
post {
Expand Down
2 changes: 1 addition & 1 deletion docs/development-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ npm install
npm run build

# Download ml-model weights
wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS}.pt
wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS,SAM}.pt

# Fetch and run database & result store (postgres)
docker run --name smt-postgres -d -p 5432:5432 -e POSTGRES_PASSWORD=smt -e POSTGRES_USER=smt postgres:15
Expand Down
4 changes: 3 additions & 1 deletion docs/model_registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The **Model Registry** maintains a collection of fine-tuned machine learning mod
| Object Detection | YOLO_OSM | 6-Channel Input | Detects sketches on OSM | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt) |
| Object Detection | YOLO_ESRI | 6-Channel Input | Detects sketches on ESRI maps | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt) |
| Image Classification | YOLO_CLS | Standard RGB | Classifies colors in sketches | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt) |
| Segmentation | SAM2 | tandard RGB | Performs segmentation on sketch | [github](https://github.com/facebookresearch/sam2) |
| Segmentation | SAM2 | Standard RGB | Performs segmentation on on sketches, finetuned from **SAM 2.1 hiera large** | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt) |

## Models in the Registry
### 1. Object Detection Models
Expand Down Expand Up @@ -40,7 +40,9 @@ This model is used to determine the sketch's **color**.
For segmentation tasks, **SAM2 (Segment Anything Model v2)** is utilized.

#### **SAM2 - Segmentation Model**
- **Base Map:** Both OSM and ESRI Satellite Imagery
- **Task:** Performs pixel-wise segmentation to extract regions from sketches.
- **Fine-tuned On:** On a set of manually selected segmented sketches to improve performance on sketch data.


For questions, contact the **SketchMapTool Team**.
Expand Down
3 changes: 2 additions & 1 deletion sketch_map_tool/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"yolo_cls": "SMT-CLS",
"yolo_osm_obj": "SMT-OSM",
"yolo_esri_obj": "SMT-ESRI",
"model_type_sam": "vit_b",
"sam_checkpoint": "SMT-SAM",
"model_type_sam": "configs/sam2.1/sam2.1_hiera_l.yaml",
"esri-api-key": "",
"log-level": "INFO",
}
Expand Down
10 changes: 6 additions & 4 deletions sketch_map_tool/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from io import BytesIO

import torch
from celery.result import AsyncResult
from celery.signals import setup_logging, worker_process_init, worker_process_shutdown
from geojson import FeatureCollection
Expand All @@ -27,7 +28,6 @@
from sketch_map_tool.upload_processing.detect_markings import detect_markings
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)
from sketch_map_tool.wms import client as wms_client
Expand Down Expand Up @@ -55,13 +55,15 @@ def init_worker_ml_models(**_):
global yolo_obj_esri
global yolo_cls

path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
config_file=get_config_value("model_type_sam"),
ckpt_path=None,
device=device,
)
sam2_model.load_state_dict(
torch.load(init_model(get_config_value("sam_checkpoint")), map_location=device)
)
sam_predictor = SAM2ImagePredictor(sam2_model)

yolo_obj_osm = YOLO_MB(init_model(get_config_value("yolo_osm_obj")))
Expand Down
14 changes: 0 additions & 14 deletions sketch_map_tool/upload_processing/ml_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from pathlib import Path

import requests
import torch
from torch._prims_common import DeviceLikeType

Expand All @@ -17,19 +16,6 @@ def init_model(id: str) -> Path:
return path


def init_sam2(id: str = "sam2_hiera_base_plus") -> Path:
raw = Path(get_config_value("weights-dir")) / id
path = raw.with_suffix(".pt")
base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
url = base_url + id + ".pt"
if not path.is_file():
logging.info(f"Downloading model SAM-2 from fbaipublicfiles.com to {path}.")
response = requests.get(url=url)
with open(path, mode="wb") as file:
file.write(response.content)
return path


def select_computation_device() -> DeviceLikeType:
"""Select computation device (cuda, mps, cpu) for SAM-2"""
if torch.cuda.is_available():
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 25 additions & 39 deletions tests/integration/upload_processing/test_detect_markings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest
from PIL import Image
import torch
from PIL import Image, ImageDraw, ImageOps
from pytest_approval import verify_image_pillow
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ultralytics import YOLO
Expand All @@ -12,7 +14,6 @@
)
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)

Expand All @@ -22,13 +23,15 @@
@pytest.fixture
def sam_predictor():
"""Zero shot segment anything model"""
path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
config_file=get_config_value("model_type_sam"),
ckpt_path=None,
device=device,
)
sam2_model.load_state_dict(
torch.load(init_model(get_config_value("sam_checkpoint")), map_location=device)
)
return SAM2ImagePredictor(sam2_model)


Expand Down Expand Up @@ -80,38 +83,21 @@ def test_detect_markings(
yolo_cls,
sam_predictor,
)
img = Image.fromarray(map_frame_marked)
for m in markings:
m[m == m.max()] = 255
colored_marking = ImageOps.colorize(
Image.fromarray(m).convert("L"), black="black", white="green"
)
img.paste(colored_marking, (0, 0), Image.fromarray(m))
# draw bbox around each marking, derived from the mask m
bbox = (
np.min(np.where(m)[1]),
np.min(np.where(m)[0]),
np.max(np.where(m)[1]),
np.max(np.where(m)[0]),
)

# NOTE: uncomment for manual/visual assessment of detected markings
# TODO: use approval test
# import random
# from PIL import ImageDraw, ImageOps
# img = Image.fromarray(map_frame_marked)
# for m in markings:
# colors = [
# "red",
# "green",
# "blue",
# "yellow",
# "purple",
# "orange",
# "pink",
# "brown",
# ]
# m[m == m.max()] = 255
# colored_marking = ImageOps.colorize(
# Image.fromarray(m).convert("L"),
# black="black",
# white=random.choice(colors),
# )
# img.paste(colored_marking, (0, 0), Image.fromarray(m))
# # draw bbox around each marking, derived from the mask m
# bbox = (
# np.min(np.where(m)[1]),
# np.min(np.where(m)[0]),
# np.max(np.where(m)[1]),
# np.max(np.where(m)[0]),
# )
#
# draw = ImageDraw.Draw(img)
# draw.rectangle(bbox, outline="red", width=2)
# img.show()
draw = ImageDraw.Draw(img)
draw.rectangle(bbox, outline="red", width=2)
assert verify_image_pillow(img, extension=".png")
1 change: 1 addition & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def config_keys():
"yolo_cls",
"yolo_osm_obj",
"yolo_esri_obj",
"sam_checkpoint",
"model_type_sam",
"esri-api-key",
"log-level",
Expand Down