diff --git a/Jenkinsfile b/Jenkinsfile index eb533e1d..f277b360 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,9 +39,10 @@ pipeline { sh 'uv sync --locked --only-group gdal-build-dependencies' sh 'uv sync --locked' sh 'uv run pybabel compile -d sketch_map_tool/translations' - sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-OSM.pt' - sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-ESRI.pt' - sh 'wget --quiet -P weights https://sketch-map-tool.heigit.org/weights/SMT-CLS.pt' + sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt' + sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt' + sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt' + sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt' } } post { diff --git a/docs/development-setup.md b/docs/development-setup.md index 9bd0bdef..1aa95371 100644 --- a/docs/development-setup.md +++ b/docs/development-setup.md @@ -49,7 +49,7 @@ npm install npm run build # Download ml-model weights -wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS}.pt +wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS,SAM}.pt # Fetch and run database & result store (postgres) docker run --name smt-postgres -d -p 5432:5432 -e POSTGRES_PASSWORD=smt -e POSTGRES_USER=smt postgres:15 diff --git a/docs/model_registry.md b/docs/model_registry.md index 530dc857..de8ab8e0 100644 --- a/docs/model_registry.md +++ b/docs/model_registry.md @@ -10,7 +10,7 @@ The **Model Registry** maintains a collection of fine-tuned machine learning mod | Object Detection | YOLO_OSM | 6-Channel Input | Detects sketches on OSM | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt) | | Object Detection | YOLO_ESRI | 6-Channel Input | Detects sketches on ESRI maps | 
[download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt) | | Image Classification | YOLO_CLS | Standard RGB | Classifies colors in sketches | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt) | -| Segmentation | SAM2 | tandard RGB | Performs segmentation on sketch | [github](https://github.com/facebookresearch/sam2) | +| Segmentation | SAM2 | Standard RGB | Performs segmentation on sketches, fine-tuned from **SAM 2.1 hiera large** | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt) | ## Models in the Registry ### 1. Object Detection Models @@ -40,7 +40,9 @@ This model is used to determine the sketch's **color**. For segmentation tasks, **SAM2 (Segment Anything Model v2)** is utilized. #### **SAM2 - Segmentation Model** +- **Base Map:** Both OSM and ESRI Satellite Imagery - **Task:** Performs pixel-wise segmentation to extract regions from sketches. +- **Fine-tuned On:** A set of manually selected segmented sketches to improve performance on sketch data. For questions, contact the **SketchMapTool Team**. 
diff --git a/sketch_map_tool/config.py b/sketch_map_tool/config.py index cb25a8a3..60c18ce9 100644 --- a/sketch_map_tool/config.py +++ b/sketch_map_tool/config.py @@ -24,7 +24,8 @@ "yolo_cls": "SMT-CLS", "yolo_osm_obj": "SMT-OSM", "yolo_esri_obj": "SMT-ESRI", - "model_type_sam": "vit_b", + "sam_checkpoint": "SMT-SAM", + "model_type_sam": "configs/sam2.1/sam2.1_hiera_l.yaml", "esri-api-key": "", "log-level": "INFO", } diff --git a/sketch_map_tool/tasks.py b/sketch_map_tool/tasks.py index f9af3d7c..9cbfd5a1 100644 --- a/sketch_map_tool/tasks.py +++ b/sketch_map_tool/tasks.py @@ -1,6 +1,7 @@ import logging from io import BytesIO +import torch from celery.result import AsyncResult from celery.signals import setup_logging, worker_process_init, worker_process_shutdown from geojson import FeatureCollection @@ -27,7 +28,6 @@ from sketch_map_tool.upload_processing.detect_markings import detect_markings from sketch_map_tool.upload_processing.ml_models import ( init_model, - init_sam2, select_computation_device, ) from sketch_map_tool.wms import client as wms_client @@ -55,13 +55,15 @@ def init_worker_ml_models(**_): global yolo_obj_esri global yolo_cls - path = init_sam2() device = select_computation_device() sam2_model = build_sam2( - config_file="sam2_hiera_b+.yaml", - ckpt_path=path, + config_file=get_config_value("model_type_sam"), + ckpt_path=None, device=device, ) + sam2_model.load_state_dict( + torch.load(init_model(get_config_value("sam_checkpoint")), map_location=device) + ) sam_predictor = SAM2ImagePredictor(sam2_model) yolo_obj_osm = YOLO_MB(init_model(get_config_value("yolo_osm_obj"))) diff --git a/sketch_map_tool/upload_processing/ml_models.py b/sketch_map_tool/upload_processing/ml_models.py index 2d5c6402..78333381 100644 --- a/sketch_map_tool/upload_processing/ml_models.py +++ b/sketch_map_tool/upload_processing/ml_models.py @@ -1,7 +1,6 @@ import logging from pathlib import Path -import requests import torch from torch._prims_common import DeviceLikeType @@ 
-17,19 +16,6 @@ def init_model(id: str) -> Path: return path -def init_sam2(id: str = "sam2_hiera_base_plus") -> Path: - raw = Path(get_config_value("weights-dir")) / id - path = raw.with_suffix(".pt") - base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" - url = base_url + id + ".pt" - if not path.is_file(): - logging.info(f"Downloading model SAM-2 from fbaipublicfiles.com to {path}.") - response = requests.get(url=url) - with open(path, mode="wb") as file: - file.write(response.content) - return path - - def select_computation_device() -> DeviceLikeType: """Select computation device (cuda, mps, cpu) for SAM-2""" if torch.cuda.is_available(): diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png new file mode 100644 index 00000000..4d2de3b8 Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png new file mode 100644 index 00000000..bd488cc5 Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png new file mode 100644 index 00000000..3cf27f47 Binary files /dev/null and 
b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png new file mode 100644 index 00000000..1c82cdc2 Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png differ diff --git a/tests/integration/upload_processing/test_detect_markings.py b/tests/integration/upload_processing/test_detect_markings.py index 61b06e89..5a7eec92 100644 --- a/tests/integration/upload_processing/test_detect_markings.py +++ b/tests/integration/upload_processing/test_detect_markings.py @@ -1,6 +1,8 @@ import numpy as np import pytest -from PIL import Image +import torch +from PIL import Image, ImageDraw, ImageOps +from pytest_approval import verify_image_pillow from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor from ultralytics import YOLO @@ -12,7 +14,6 @@ ) from sketch_map_tool.upload_processing.ml_models import ( init_model, - init_sam2, select_computation_device, ) @@ -22,13 +23,15 @@ @pytest.fixture def sam_predictor(): """Zero shot segment anything model""" - path = init_sam2() device = select_computation_device() sam2_model = build_sam2( - config_file="sam2_hiera_b+.yaml", - ckpt_path=path, + config_file=get_config_value("model_type_sam"), + ckpt_path=None, device=device, ) + sam2_model.load_state_dict( + torch.load(init_model(get_config_value("sam_checkpoint")), map_location=device) + ) return SAM2ImagePredictor(sam2_model) @@ -80,38 +83,21 @@ def test_detect_markings( yolo_cls, sam_predictor, ) + img = Image.fromarray(map_frame_marked) + for m in markings: + m[m == m.max()] = 255 + colored_marking = ImageOps.colorize( + 
Image.fromarray(m).convert("L"), black="black", white="green" + ) + img.paste(colored_marking, (0, 0), Image.fromarray(m)) + # draw bbox around each marking, derived from the mask m + bbox = ( + np.min(np.where(m)[1]), + np.min(np.where(m)[0]), + np.max(np.where(m)[1]), + np.max(np.where(m)[0]), + ) - # NOTE: uncomment for manual/visual assessment of detected markings - # TODO: use approval test - # import random - # from PIL import ImageDraw, ImageOps - # img = Image.fromarray(map_frame_marked) - # for m in markings: - # colors = [ - # "red", - # "green", - # "blue", - # "yellow", - # "purple", - # "orange", - # "pink", - # "brown", - # ] - # m[m == m.max()] = 255 - # colored_marking = ImageOps.colorize( - # Image.fromarray(m).convert("L"), - # black="black", - # white=random.choice(colors), - # ) - # img.paste(colored_marking, (0, 0), Image.fromarray(m)) - # # draw bbox around each marking, derived from the mask m - # bbox = ( - # np.min(np.where(m)[1]), - # np.min(np.where(m)[0]), - # np.max(np.where(m)[1]), - # np.max(np.where(m)[0]), - # ) - # - # draw = ImageDraw.Draw(img) - # draw.rectangle(bbox, outline="red", width=2) - # img.show() + draw = ImageDraw.Draw(img) + draw.rectangle(bbox, outline="red", width=2) + assert verify_image_pillow(img, extension=".png") diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 7d8ceeb2..0be119c2 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -26,6 +26,7 @@ def config_keys(): "yolo_cls", "yolo_osm_obj", "yolo_esri_obj", + "sam_checkpoint", "model_type_sam", "esri-api-key", "log-level",