diff --git a/docs/pretrained.rst b/docs/pretrained.rst index 310fc83bc..1a2a53faf 100644 --- a/docs/pretrained.rst +++ b/docs/pretrained.rst @@ -353,7 +353,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -369,7 +369,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -393,7 +393,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -409,7 +409,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) diff --git a/tests/engines/test_nucleus_detection_engine.py b/tests/engines/test_nucleus_detection_engine.py new file mode 100644 index 000000000..c140db07b --- /dev/null +++ b/tests/engines/test_nucleus_detection_engine.py @@ -0,0 +1,323 @@ +"""Tests for NucleusDetector.""" + +import shutil +from collections.abc import Callable +from pathlib import Path + +import dask.array as da +import numpy as np +import pytest +import zarr +from click.testing import CliRunner + +from tiatoolbox import cli +from tiatoolbox.annotation.storage import SQLiteStore +from tiatoolbox.models.engine.nucleus_detector import ( + NucleusDetector, +) +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import imwrite +from tiatoolbox.wsicore.wsireader import WSIReader + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def _rm_dir(path: Path) -> None: + """Helper func to remove directory.""" + if path.exists(): + shutil.rmtree(path, ignore_errors=True) + + +def test_centroid_maps_to_detection_arrays() -> None: + """Convert centroid maps to detection arrays.""" + detection_maps = np.zeros((4, 4, 2), dtype=np.float32) + detection_maps[1, 1, 0] = 1.0 + detection_maps[2, 3, 1] = 0.5 + detection_maps = da.from_array(detection_maps, chunks=(2, 2, 2)) + + detections = NucleusDetector._centroid_maps_to_detection_arrays(detection_maps) + + xs = detections["x"] + ys = detections["y"] + classes = detections["classes"] + probs = detections["probabilities"] + + np.testing.assert_array_equal(xs, np.array([1, 3], dtype=np.uint32)) + np.testing.assert_array_equal(ys, np.array([1, 2], dtype=np.uint32)) + np.testing.assert_array_equal(classes, np.array([0, 1], dtype=np.uint32)) + np.testing.assert_array_equal(probs, np.array([1.0, 0.5], dtype=np.float32)) + + +def test_write_detection_arrays_to_store() -> None: + """Test writing detection arrays to annotation store.""" + detection_arrays = { + "x": np.array([1, 3], dtype=np.uint32), + "y": np.array([1, 2], dtype=np.uint32), + "classes": np.array([0, 1], dtype=np.uint32), + "probabilities": np.array([1.0, 0.5], dtype=np.float32), + } + + store = NucleusDetector.save_detection_arrays_to_store(detection_arrays) + assert len(store.values()) == 2 + + detection_arrays = { + "x": np.array([1], dtype=np.uint32), + "y": np.array([1, 2], 
dtype=np.uint32), + "classes": np.array([0], dtype=np.uint32), + "probabilities": np.array([1.0, 0.5], dtype=np.float32), + } + with pytest.raises( + ValueError, + match=r"Detection record lengths are misaligned.", + ): + _ = NucleusDetector.save_detection_arrays_to_store(detection_arrays) + + +def test_write_detection_records_to_store_no_class_dict() -> None: + """Test writing detection records to annotation store.""" + detection_records = (np.array([1]), np.array([2]), np.array([0]), np.array([1.0])) + + dummy_store = SQLiteStore() + total = NucleusDetector._write_detection_arrays_to_store( + detection_records, store=dummy_store, scale_factor=(1.0, 1.0), class_dict=None + ) + assert len(dummy_store.values()) == 1 + assert total == 1 + annotation = next(iter(dummy_store.values())) + assert annotation.properties["type"] == 0 + dummy_store.close() + + +def test_nucleus_detector_patch_annotation_store_output( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Test for nucleus detection engine in patch mode.""" + mini_wsi_svs = Path(remote_sample("wsi1_2k_2k_svs")) + + wsi_reader = WSIReader.open(mini_wsi_svs) + patch_1 = wsi_reader.read_bounds( + (30, 30, 61, 61), + resolution=0.25, + units="mpp", + coord_space="resolution", + ) + patch_2 = np.zeros((31, 31, 3), dtype=np.uint8) + + pretrained_model = "sccnn-conic" + + save_dir = track_tmp_path + + nucleus_detector = NucleusDetector(model=pretrained_model) + _ = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="annotationstore", + memory_threshold=50, + images=[patch_1, patch_2], + save_dir=save_dir, + overwrite=True, + class_dict=None, + ) + + store_1 = SQLiteStore.open(save_dir / "0.db") + assert len(store_1.values()) == 1 + store_1.close() + + store_2 = SQLiteStore.open(save_dir / "1.db") + assert len(store_2.values()) == 0 + store_2.close() + + imwrite(save_dir / "patch_0.png", patch_1) + imwrite(save_dir / "patch_1.png", patch_2) + _ = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="annotationstore", + memory_threshold=50, + images=[save_dir / "patch_0.png", save_dir / "patch_1.png"], + save_dir=save_dir, + overwrite=True, + ) + + store_1 = SQLiteStore.open(save_dir / "patch_0.db") + assert len(store_1.values()) == 1 + store_1.close() + + store_2 = SQLiteStore.open(save_dir / "patch_1.db") + assert len(store_2.values()) == 0 + store_2.close() + + _rm_dir(save_dir) + + +def test_nucleus_detector_patches_dict_output( + remote_sample: Callable, +) -> None: + """Test for nucleus detection engine in patch mode.""" + mini_wsi_svs = Path(remote_sample("wsi1_2k_2k_svs")) + + wsi_reader = WSIReader.open(mini_wsi_svs) + patch_1 = wsi_reader.read_bounds( + (30, 30, 61, 61), + resolution=0.25, + units="mpp", + coord_space="resolution", + ) + patch_2 = np.zeros_like(patch_1) + + model = "sccnn-conic" + + nucleus_detector = NucleusDetector(model=model) + + output_dict = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="dict", + memory_threshold=50, + images=np.stack([patch_1, patch_2], axis=0), + save_dir=None, + class_dict=None, + return_probabilities=True, + ) + assert len(output_dict["x"]) == 2 + assert len(output_dict["y"]) == 2 + assert len(output_dict["classes"]) == 2 + assert len(output_dict["probabilities"]) == 2 + assert len(output_dict["x"][0]) == 1 + assert len(output_dict["x"][1]) == 0 + assert len(output_dict["y"][0]) == 1 + assert len(output_dict["y"][1]) == 0 + assert len(output_dict["classes"][0]) == 1 + assert len(output_dict["classes"][1]) == 0 + 
assert len(output_dict["probabilities"][0]) == 1 + assert len(output_dict["probabilities"][1]) == 0 + + +def test_nucleus_detector_patches_zarr_output( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Test for nucleus detection engine in patch mode.""" + mini_wsi_svs = Path(remote_sample("wsi1_2k_2k_svs")) + wsi_reader = WSIReader.open(mini_wsi_svs) + patch_1 = wsi_reader.read_bounds( + (30, 30, 61, 61), + resolution=0.25, + units="mpp", + coord_space="resolution", + ) + patch_2 = np.zeros_like(patch_1) + + pretrained_model = "sccnn-conic" + + nucleus_detector = NucleusDetector(model=pretrained_model) + + save_dir = track_tmp_path + + output_path = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="zarr", + memory_threshold=50, + images=np.stack([patch_1, patch_2], axis=0), + save_dir=save_dir, + class_dict=None, + overwrite=True, + return_probabilities=True, + ) + + output_zarr = zarr.open(output_path, mode="r") + + assert output_zarr["x"][0].size == 1 + assert output_zarr["x"][1].size == 0 + assert output_zarr["y"][0].size == 1 + assert output_zarr["y"][1].size == 0 + assert output_zarr["classes"][0].size == 1 + assert output_zarr["classes"][1].size == 0 + assert output_zarr["probabilities"][0].size == 1 + assert output_zarr["probabilities"][1].size == 0 + + _rm_dir(save_dir) + + +def test_nucleus_detector_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test for nucleus detection engine.""" + mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) + + pretrained_model = "sccnn-conic" + + save_dir = track_tmp_path + + nucleus_detector = NucleusDetector(model=pretrained_model) + nucleus_detector.drop_keys = [] + _ = nucleus_detector.run( + patch_mode=False, + device=device, + output_type="annotationstore", + memory_threshold=50, + images=[mini_wsi_svs], + save_dir=save_dir, + overwrite=True, + batch_size=8, + class_dict={0: "test_nucleus"}, + min_distance=5, + postproc_tile_shape=(2048, 2048), + ) + + store = SQLiteStore.open(save_dir / "wsi4_512_512.db") + assert 255 <= len(store.values()) <= 265 + annotation = next(iter(store.values())) + assert annotation.properties["type"] == "test_nucleus" + store.close() + + nucleus_detector.drop_keys = ["probabilities"] + result_path = nucleus_detector.run( + patch_mode=False, + device=device, + output_type="zarr", + memory_threshold=50, + images=[mini_wsi_svs], + save_dir=save_dir, + overwrite=True, + batch_size=8, + ) + print("Result path:", result_path) + + zarr_path = result_path[mini_wsi_svs] + zarr_group = zarr.open(zarr_path, mode="r") + xs = zarr_group["x"][:] + ys = zarr_group["y"][:] + classes = zarr_group["classes"][:] + probs = zarr_group.get("probabilities", None) + assert probs is None + assert 255 <= len(xs) <= 265 + assert 255 <= len(ys) <= 265 + assert 255 <= len(classes) <= 265 + + _rm_dir(save_dir) + mini_wsi_svs.unlink() + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test nucleus detector CLI single file.""" + runner = CliRunner() + mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) + models_wsi_result = runner.invoke( + cli.main, + [ + "nucleus-detector", + "--img-input", + str(mini_wsi_svs), + "--patch-mode", + "False", + "--output-path", + str(track_tmp_path / "output"), + ], + ) + + assert 
models_wsi_result.exit_code == 0, models_wsi_result.output + assert (track_tmp_path / "output" / ("wsi4_512_512" + ".db")).exists() diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 19163f593..2e30d6fcb 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -48,10 +48,91 @@ def test_functionality(remote_sample: Callable) -> None: batch = torch.from_numpy(patch)[None] output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) - assert np.all(output[0:2] == [[19, 171], [53, 89]]) + ( + ys, + xs, + _, + ) = np.nonzero(output) + + np.testing.assert_array_equal(xs[0:2], np.array([242, 192])) + np.testing.assert_array_equal(ys[0:2], np.array([10, 13])) + + patch = reader.read_bounds( + (0, 0, 252, 252), + resolution=0.50, + units="mpp", + coord_space="resolution", + ) + + patch = model.preproc(patch) + batch = torch.from_numpy(patch)[None] + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + block_info = { + 0: { + "array-location": [ + [0, 1], + [0, 1], + ], # dummy block to test no valid detections + } + } + output = model.postproc(output[0], block_info=block_info) + ys, xs, _ = np.nonzero(output) + np.testing.assert_array_equal(xs, np.array([])) + np.testing.assert_array_equal(ys, np.array([])) + Path(weights_path).unlink() +def test_postproc_params_override(remote_sample: Callable) -> None: + """Test MapDe post-processing with overridden parameters.""" + sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) + reader = WSIReader.open(sample_wsi) + + # Read an input patch from the sample WSI at the model resolution + patch = reader.read_bounds( + (0, 0, 252, 252), + resolution=0.50, + units="mpp", + coord_space="resolution", + ) + + model, weight_path = _load_mapde(name="mapde-conic") + patch = model.preproc(patch) + batch = torch.from_numpy(patch)[None] + raw_output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + + output_normal = model.postproc(raw_output[0]) + ( + ys_normal, + xs_normal, + _, + ) = np.nonzero(output_normal) + + # A higher threshold should result in fewer detections + output_high_threshold = model.postproc(raw_output[0], threshold_abs=500) + ( + ys_high_threshold, + xs_high_threshold, + _, + ) = np.nonzero(output_high_threshold) + + # A larger min_distance should result in fewer detections + output_large_min_distance = model.postproc(raw_output[0], min_distance=9) + ( + ys_large_min_distance, + xs_large_min_distance, + _, + ) = np.nonzero(output_large_min_distance) + + assert len(xs_high_threshold) < len(xs_normal) + assert len(ys_high_threshold) < len(ys_normal) + + assert len(xs_large_min_distance) < len(xs_normal) + assert len(ys_large_min_distance) < len(ys_normal) + + Path(weight_path).unlink() + + def test_multiclass_output() -> None: """Test the architecture for multi-class output.""" multiclass_model = MapDe(num_input_channels=3, num_classes=3) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index a456faff5..5c889cd0c 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -48,7 +48,10 @@ def test_functionality(remote_sample: Callable) -> None: device=select_device(on_gpu=env_detection.has_gpu()), ) output = model.postproc(output[0]) - np.testing.assert_array_equal(output, np.array([[8, 7]])) + ys, xs, _ = np.nonzero(output) + + np.testing.assert_array_equal(xs, np.array([8])) + np.testing.assert_array_equal(ys, np.array([7])) model = 
_load_sccnn(name="sccnn-conic") output = model.infer_batch( @@ -56,5 +59,64 @@ def test_functionality(remote_sample: Callable) -> None: batch, device=select_device(on_gpu=env_detection.has_gpu()), ) - output = model.postproc(output[0]) - np.testing.assert_array_equal(output, np.array([[7, 8]])) + block_info = { + 0: { + "array-location": [[0, 31], [0, 31]], + } + } + output = model.postproc(output[0], block_info=block_info) + ys, xs, _ = np.nonzero(output) + np.testing.assert_array_equal(xs, np.array([7])) + np.testing.assert_array_equal(ys, np.array([8])) + + model = _load_sccnn(name="sccnn-conic") + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) + block_info = { + 0: { + "array-location": [ + [0, 1], + [0, 1], + ], # dummy block to test no valid detections + } + } + output = model.postproc(output[0], block_info=block_info) + ys, xs, _ = np.nonzero(output) + np.testing.assert_array_equal(xs, np.array([])) + np.testing.assert_array_equal(ys, np.array([])) + + +def test_postproc_params_override(remote_sample: Callable) -> None: + """Test postproc parameters override.""" + sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) + reader = WSIReader.open(sample_wsi) + + # * test fast mode (architecture used in PanNuke paper) + patch = reader.read_bounds( + (30, 30, 61, 61), + resolution=0.25, + units="mpp", + coord_space="resolution", + ) + model = _load_sccnn(name="sccnn-crchisto") + patch = model.preproc(patch) + batch = torch.from_numpy(patch)[None] + raw_output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) + # Override to a high threshold to get no detections + output = model.postproc(raw_output[0], threshold_abs=0.9) + ys, xs, _ = np.nonzero(output) + np.testing.assert_array_equal(xs, np.array([])) + np.testing.assert_array_equal(ys, np.array([])) + + # Override with small min_distance + output = model.postproc(raw_output[0], min_distance=1) + ys, xs, _ = np.nonzero(output) + np.testing.assert_array_equal(xs, np.array([8])) + np.testing.assert_array_equal(ys, np.array([7])) diff --git a/tests/models/test_arch_utils.py b/tests/models/test_arch_utils.py index 8ac04bc92..8f5be5f56 100644 --- a/tests/models/test_arch_utils.py +++ b/tests/models/test_arch_utils.py @@ -1,5 +1,6 @@ """Unit test package for architecture utilities.""" +import dask.array as da import numpy as np import pytest import torch @@ -8,6 +9,7 @@ UpSample2x, centre_crop, centre_crop_to_shape, + peak_detection_map_overlap, ) @@ -65,3 +67,96 @@ def test_centre_crop_operators() -> None: y = x[:, :, 6:9, 6:9] with pytest.raises(ValueError, match=r".*Height.*smaller than `y`*"): centre_crop_to_shape(y, x, data_format="NCHW") + + +def test_peak_detection() -> None: + """Test for peak detection.""" + min_distance = 3 + threshold_abs = 0.5 + + heatmap = np.zeros((7, 7, 1), dtype=np.float32) + + peak_map = peak_detection_map_overlap( + heatmap, + min_distance=min_distance, + threshold_abs=threshold_abs, + ) + assert np.sum(peak_map) == 0.0 # No peaks + + heatmap[0, 0, 0] = 0.9 # First peak + heatmap[0, 1, 0] = 0.6 # Too close to first peak + heatmap[1, 0, 0] = 0.6 # Too close to first peak + heatmap[2, 2, 0] = 0.9 # Too close to first peak + heatmap[3, 3, 0] = 0.9 # Second peak + + peak_map = peak_detection_map_overlap( + heatmap, + min_distance=min_distance, + threshold_abs=threshold_abs, + ) + assert peak_map[0, 0, 0] == 1.0 + assert peak_map[3, 3, 0] == 1.0 + assert np.sum(peak_map) == 2.0 + + +def 
test_peak_detection_map_overlap() -> None: + """Test for peak detection with da.map_overlap.""" + heatmap = np.zeros((7, 7, 1), dtype=np.float32) + heatmap[0, 0, 0] = 0.9 # First peak + heatmap[0, 1, 0] = 0.6 # Too close to first peak + heatmap[1, 0, 0] = 0.6 # Too close to first peak + heatmap[2, 2, 0] = 0.9 # Too close to first peak + heatmap[3, 3, 0] = 0.9 # Second peak + + min_distance = 3 + threshold_abs = 0.5 + + # Add halo (overlap) around each block for post-processing + depth_h = min_distance + depth_w = min_distance + depth = {0: depth_h, 1: depth_w, 2: 0} + + # Test chunk is entire heatmap + da_heatmap = da.from_array(heatmap, chunks=(7, 7, 1)) + + da_peak_map = da.map_overlap( + da_heatmap, + peak_detection_map_overlap, + depth=depth, + boundary=0, + dtype=np.float32, + block_info=True, + depth_h=depth_h, + depth_w=depth_w, + threshold_abs=threshold_abs, + min_distance=min_distance, + ) + + peak_map = da_peak_map.compute() + + assert peak_map[0, 0, 0] == 1.0 + assert peak_map[3, 3, 0] == 1.0 + assert np.sum(peak_map) == 2.0 + + # Test small chunk with halo + # using very small chunk sizes (1,1,1) to force multiple overlaps + da_heatmap = da_heatmap.rechunk({0: 1, 1: 1, 2: 1}) + + da_peak_map = da.map_overlap( + da_heatmap, + peak_detection_map_overlap, + depth=depth, + boundary=0, + dtype=np.float32, + block_info=True, + depth_h=depth_h, + depth_w=depth_w, + threshold_abs=threshold_abs, + min_distance=min_distance, + ) + + peak_map = da_peak_map.compute() + + assert peak_map[0, 0, 0] == 1.0 + assert peak_map[3, 3, 0] == 1.0 + assert np.sum(peak_map) == 2.0 diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index b11f31f96..56c9a5dc6 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -8,6 +8,7 @@ from tiatoolbox import __version__ from tiatoolbox.cli.common import tiatoolbox_cli from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor +from tiatoolbox.cli.nucleus_detector import nucleus_detector from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment from tiatoolbox.cli.patch_predictor import patch_predictor from tiatoolbox.cli.read_bounds import read_bounds @@ -44,6 +45,7 @@ def main() -> click.BaseCommand: main.add_command(read_bounds) main.add_command(save_tiles) main.add_command(semantic_segmentor) +main.add_command(nucleus_detector) main.add_command(deep_feature_extractor) main.add_command(slide_info) main.add_command(slide_thumbnail) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 928037762..d239a6ec5 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -626,6 +626,61 @@ def cli_yaml_config_path( ) +def cli_min_distance( + usage_help: str = "Minimum distance separating two nuclei (in pixels).", + default: int | None = None, +) -> Callable: + """Enables --min_distance option for cli.""" + return click.option( + "--min_distance", + type=int, + help=add_default_to_usage_help(usage_help, default=default), + default=default, + ) + + +def cli_threshold_abs( + usage_help: str = "Absolute detection threshold applied to model outputs.", + default: float | None = None, +) -> Callable: + """Enables --threshold_abs option for cli.""" + return click.option( + "--threshold_abs", + type=float, + help=add_default_to_usage_help(usage_help, default=default), + default=default, + ) + + +def cli_threshold_rel( + usage_help: str = "Relative detection threshold" + " (e.g., with respect to local maxima).", + default: float | None = None, +) -> Callable: + """Enables 
--threshold_rel option for cli.""" + return click.option( + "--threshold_rel", + type=float, + help=add_default_to_usage_help(usage_help, default=default), + default=default, + ) + + +def cli_postproc_tile_shape( + usage_help: str = " Tile shape (height, width) used during post-processing " + "(in pixels) to control rechunking behavior.", + default: IntPair | None = None, +) -> Callable: + """Enables --postproc_tile_shape option for cli.""" + return click.option( + "--postproc_tile_shape", + type=int, + default=default, + nargs=2, + help=usage_help, + ) + + def cli_num_workers( usage_help: str = "Number of workers to load the data. Please note that they will " "also perform preprocessing.", diff --git a/tiatoolbox/cli/nucleus_detector.py b/tiatoolbox/cli/nucleus_detector.py new file mode 100644 index 000000000..290ebddd3 --- /dev/null +++ b/tiatoolbox/cli/nucleus_detector.py @@ -0,0 +1,168 @@ +"""Command line interface for nucleus detection.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tiatoolbox.cli.common import ( + cli_auto_get_mask, + cli_batch_size, + cli_class_dict, + cli_device, + cli_file_type, + cli_img_input, + cli_input_resolutions, + cli_masks, + cli_memory_threshold, + cli_min_distance, + cli_model, + cli_num_workers, + cli_output_file, + cli_output_path, + cli_output_resolutions, + cli_output_type, + cli_overwrite, + cli_patch_input_shape, + cli_patch_mode, + cli_patch_output_shape, + cli_postproc_tile_shape, + cli_return_probabilities, + cli_scale_factor, + cli_stride_shape, + cli_threshold_abs, + cli_threshold_rel, + cli_verbose, + cli_weights, + cli_yaml_config_path, + prepare_ioconfig, + prepare_model_cli, + tiatoolbox_cli, +) + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import IntPair + + +@tiatoolbox_cli.command() +@cli_img_input() +@cli_output_path( + usage_help="Output directory where model prediction will be saved.", + default="nucleus_detection", +) +@cli_output_file(default=None) +@cli_file_type( + default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", +) +@cli_input_resolutions(default=None) +@cli_output_resolutions(default=None) +@cli_class_dict(default=None) +@cli_model(default="mapde-conic") +@cli_weights() +@cli_device(default="cpu") +@cli_batch_size(default=1) +@cli_yaml_config_path() +@cli_masks(default=None) +@cli_num_workers(default=0) +@cli_output_type( + default="AnnotationStore", +) +@cli_memory_threshold(default=80) +@cli_patch_input_shape(default=None) +@cli_patch_output_shape(default=None) +@cli_min_distance(default=None) +@cli_threshold_abs(default=None) +@cli_threshold_rel(default=None) +@cli_postproc_tile_shape(default=None) +@cli_stride_shape(default=None) +@cli_scale_factor(default=None) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_auto_get_mask(default=True) +@cli_overwrite(default=False) +@cli_verbose(default=True) +def nucleus_detector( + model: str, + weights: str, + img_input: str, + file_types: str, + class_dict: list[tuple[int, str]], + input_resolutions: list[dict], + output_resolutions: list[dict], + masks: str | None, + output_path: str, + patch_input_shape: IntPair | None, + patch_output_shape: tuple[int, int] | None, + stride_shape: IntPair | None, + scale_factor: tuple[float, float] | None, + batch_size: int, + yaml_config_path: str, + num_workers: int, + device: str, + output_type: str, + memory_threshold: int, + output_file: str | None, + min_distance: int | None, + threshold_abs: float | None, + threshold_rel: 
float | None, + postproc_tile_shape: IntPair | None, + *, + patch_mode: bool, + return_probabilities: bool, + auto_get_mask: bool, + verbose: bool, + overwrite: bool, +) -> None: + """Process a set of input images with a nucleus detection engine.""" + from tiatoolbox.models import IOSegmentorConfig, NucleusDetector # noqa: PLC0415 + + class_dict = dict(class_dict) if class_dict else None + + files_all, masks_all, output_path = prepare_model_cli( + img_input=img_input, + output_path=output_path, + masks=masks, + file_types=file_types, + ) + + ioconfig = prepare_ioconfig( + IOSegmentorConfig, + pretrained_weights=weights, + yaml_config_path=yaml_config_path, + ) + + detector = NucleusDetector( + model=model, + weights=weights, + batch_size=batch_size, + num_workers=num_workers, + verbose=verbose, + ) + + _ = detector.run( + images=files_all, + masks=masks_all, + class_dict=class_dict, + patch_mode=patch_mode, + patch_input_shape=patch_input_shape, + patch_output_shape=patch_output_shape, + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + batch_size=batch_size, + ioconfig=ioconfig, + device=device, + save_dir=output_path, + output_type=output_type, + return_probabilities=return_probabilities, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, + num_workers=num_workers, + output_file=output_file, + scale_factor=scale_factor, + stride_shape=stride_shape, + min_distance=min_distance, + threshold_abs=threshold_abs, + threshold_rel=threshold_rel, + postproc_tile_shape=postproc_tile_shape, + overwrite=overwrite, + verbose=verbose, + ) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 3a4ccab9b..9f715593e 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -814,6 +814,10 @@ mapde-crchisto: min_distance: 4 threshold_abs: 250 num_classes: 1 + postproc_tile_shape: [ 2048, 2048 ] + class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -821,7 +825,6 @@ mapde-crchisto: - { "units": "mpp", "resolution": 0.5 } output_resolutions: - { "units": "mpp", "resolution": 0.5 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] @@ -836,6 +839,10 @@ mapde-conic: min_distance: 3 threshold_abs: 205 num_classes: 1 + postproc_tile_shape: [ 2048, 2048 ] + class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -843,7 +850,6 @@ mapde-conic: - { "units": "mpp", "resolution": 0.5 } output_resolutions: - { "units": "mpp", "resolution": 0.5 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] @@ -859,6 +865,10 @@ sccnn-crchisto: min_distance: 6 threshold_abs: 0.20 patch_output_shape: [ 13, 13 ] + postproc_tile_shape: [ 2048, 2048 ] + class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -866,7 +876,6 @@ sccnn-crchisto: - { "units": "mpp", "resolution": 0.25 } output_resolutions: - { "units": "mpp", "resolution": 0.25 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] @@ -882,6 +891,10 @@ sccnn-conic: min_distance: 5 threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] + postproc_tile_shape: [ 2048, 2048 ] + class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -889,7 +902,6 @@ sccnn-conic: - { "units": "mpp", "resolution": 0.25 } output_resolutions: - { 
"units": "mpp", "resolution": 0.25 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 4265f7caf..cd852cffa 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -19,6 +19,7 @@ ModelIOConfigABC, ) from .engine.multi_task_segmentor import MultiTaskSegmentor +from .engine.nucleus_detector import NucleusDetector from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor from .engine.patch_predictor import PatchPredictor from .engine.semantic_segmentor import SemanticSegmentor diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index 0900aa6fd..863a28b26 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -11,9 +11,9 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 -from skimage.feature import peak_local_max from tiatoolbox.models.architecture.micronet import MicroNet +from tiatoolbox.models.architecture.utils import peak_detection_map_overlap class MapDe(MicroNet): @@ -78,6 +78,8 @@ def __init__( min_distance: int = 4, threshold_abs: float = 250, num_classes: int = 1, + postproc_tile_shape: tuple[int, int] = (2048, 2048), + class_dict: dict[int, str] | None = None, ) -> None: """Initialize :class:`MapDe`.""" super().__init__( @@ -85,6 +87,8 @@ def __init__( num_input_channels=num_input_channels, out_activation="relu", ) + self.output_class_dict = class_dict + self.postproc_tile_shape = postproc_tile_shape dist_filter = np.array( [ @@ -233,28 +237,60 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor: return F.relu(out) # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(self: MapDe, prediction_map: np.ndarray) -> np.ndarray: - """Post-processing script for MicroNet. + def postproc( + self: MapDe, + block: np.ndarray, + min_distance: int | None = None, + threshold_abs: float | None = None, + threshold_rel: float | None = None, + block_info: dict | None = None, + depth_h: int = 0, + depth_w: int = 0, + ) -> np.ndarray: + """MapDe post-processing function. + + Builds a processed mask per input channel, runs peak_local_max then + writes 1.0 at peak pixels. - Performs peak detection and extracts coordinates in x, y format. + Returns same spatial shape as the input block Args: - prediction_map (ndarray): - Input image of type numpy array. + block (np.ndarray): + shape (H, W, C). + min_distance (int | None): + The minimal allowed distance separating peaks. + threshold_abs (float | None): + Minimum intensity of peaks. + threshold_rel (float | None): + Minimum intensity of peaks. + block_info (dict | None): + Dask block info dict. Only used when called from + dask.array.map_overlap. + depth_h (int): + Halo size in pixels for height (rows). Only used + when it's called from dask.array.map_overlap. + depth_w (int): + Halo size in pixels for width (cols). Only used + when it's called from dask.array.map_overlap. Returns: - :class:`numpy.ndarray`: - Pixel-wise nuclear instance segmentation - prediction. - + out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere. 
""" - coordinates = peak_local_max( - np.squeeze(prediction_map[0], axis=2), - min_distance=self.min_distance, - threshold_abs=self.threshold_abs, - exclude_border=False, + min_distance_to_use = ( + self.min_distance if min_distance is None else min_distance + ) + threshold_abs_to_use = ( + self.threshold_abs if threshold_abs is None else threshold_abs + ) + return peak_detection_map_overlap( + block, + min_distance=min_distance_to_use, + threshold_abs=threshold_abs_to_use, + threshold_rel=threshold_rel, + block_info=block_info, + depth_h=depth_h, + depth_w=depth_w, ) - return np.fliplr(coordinates) @staticmethod def infer_batch( @@ -262,7 +298,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str, - ) -> list[np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. This contains logic for forward operation as well as batch I/O @@ -293,8 +329,4 @@ def infer_batch( pred = model(patch_imgs_gpu) pred = pred.permute(0, 2, 3, 1).contiguous() - pred = pred.cpu().numpy() - - return [ - pred, - ] + return pred.cpu().numpy() diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index 2c47f9d12..837975ab5 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -10,12 +10,15 @@ from __future__ import annotations from collections import OrderedDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + import numpy as np -import numpy as np import torch -from skimage.feature import peak_local_max from torch import nn +from tiatoolbox.models.architecture.utils import peak_detection_map_overlap from tiatoolbox.models.models_abc import ModelABC @@ -91,6 +94,8 @@ def __init__( radius: int = 12, min_distance: int = 6, threshold_abs: float = 0.20, + postproc_tile_shape: tuple[int, int] = (2048, 2048), + class_dict: dict[int, str] | None = None, ) -> None: """Initialize :class:`SCCNN`.""" super().__init__() @@ -99,6 +104,8 @@ def __init__( self.in_ch = num_input_channels self.out_height = out_height self.out_width = out_width + self.postproc_tile_shape = postproc_tile_shape + self.output_class_dict = class_dict # Create mesh grid and convert to 3D vector x, y = torch.meshgrid( @@ -325,35 +332,68 @@ def spatially_constrained_layer1( return self.spatially_constrained_layer2(s1_sigmoid0, s1_sigmoid1, s1_sigmoid2) # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray: - """Post-processing script for MicroNet. + def postproc( + self: SCCNN, + block: np.ndarray, + min_distance: int | None = None, + threshold_abs: float | None = None, + threshold_rel: float | None = None, + block_info: dict | None = None, + depth_h: int = 0, + depth_w: int = 0, + ) -> np.ndarray: + """SCCNN post-processing function. - Performs peak detection and extracts coordinates in x, y format. + Builds a processed mask per input channel, runs peak_local_max then + writes 1.0 at peak pixels. + + Returns same spatial shape as the input block Args: - prediction_map (ndarray): - Input image of type numpy array. + block (np.ndarray): + shape (H, W, C). + min_distance (int | None): + The minimal allowed distance separating peaks. + threshold_abs (float | None): + Minimum intensity of peaks. + threshold_rel (float | None): + Minimum intensity of peaks. + block_info (dict | None): + Dask block info dict. Only used when called from + dask.array.map_overlap. + depth_h (int): + Halo size in pixels for height (rows). Only used + when it's called from dask.array.map_overlap. 
+ depth_w (int): + Halo size in pixels for width (cols). Only used + when it's called from dask.array.map_overlap. Returns: - :class:`numpy.ndarray`: - Pixel-wise nuclear instance segmentation - prediction. - + out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere. """ - coordinates = peak_local_max( - np.squeeze(prediction_map[0], axis=2), - min_distance=self.min_distance, - threshold_abs=self.threshold_abs, - exclude_border=False, + min_distance_to_use = ( + self.min_distance if min_distance is None else min_distance + ) + threshold_abs_to_use = ( + self.threshold_abs if threshold_abs is None else threshold_abs + ) + return peak_detection_map_overlap( + block, + min_distance=min_distance_to_use, + threshold_abs=threshold_abs_to_use, + threshold_rel=threshold_rel, + block_info=block_info, + depth_h=depth_h, + depth_w=depth_w, ) - return np.fliplr(coordinates) @staticmethod def infer_batch( model: nn.Module, - batch_data: np.ndarray | torch.Tensor, + batch_data: torch.Tensor, + *, device: str, - ) -> list[np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. This contains logic for forward operation as well as batch I/O @@ -386,8 +426,4 @@ def infer_batch( pred = model(patch_imgs_gpu) pred = pred.permute(0, 2, 3, 1).contiguous() - pred = pred.cpu().numpy() - - return [ - pred, - ] + return pred.cpu().numpy() diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index 9b60cc7a9..f47707361 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -7,6 +7,7 @@ import numpy as np import torch +from skimage.feature import peak_local_max from torch import nn from tiatoolbox import logger @@ -251,3 +252,75 @@ def argmax_last_axis(image: np.ndarray) -> np.ndarray: """ return image.argmax(axis=-1) + + +def peak_detection_map_overlap( + block: np.ndarray, + min_distance: int, + threshold_abs: float | None = None, + threshold_rel: float | None = None, + block_info: dict | None = None, + depth_h: int = 0, + depth_w: int = 0, +) -> np.ndarray: + """Post-processing function for peak detection. + + Builds a processed mask per input channel. Runs peak_local_max then + writes 1.0 at peak pixels. + + Can be called from Dask.da.map_overlap on a padded NumPy block + (h_pad, w_pad, C) to process large prediction maps in chunks with overlap. + Keeps only centroids whose (row,col) lie in the interior window: + rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w) + + Returns same spatial shape as the input block + + Args: + block: NumPy array (H, W, C). + min_distance: Minimum number of pixels separating peaks. + threshold_abs: Minimum intensity of peaks. By default, None. + threshold_rel: Minimum relative intensity of peaks. By default, None. + block_info: Dask block info dict. + Only used when called from dask.array.map_overlap. + depth_h: Halo size in pixels for height (rows). + Only used when called from dask.array.map_overlap. + depth_w: Halo size in pixels for width (cols). + Only used when it's called from dask.array.map_overlap. + + Returns: + out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere. 
+ + """ + block_height, block_width, block_channels = block.shape + + # --- derive core (pre-overlap) size for THIS block --- + if block_info is None: + core_h = block_height - 2 * depth_h + core_w = block_width - 2 * depth_w + else: + info = block_info[0] + locs = info["array-location"] # a list of (start, stop) coordinates per axis + core_h = int(locs[0][1] - locs[0][0]) # r1 - r0 + core_w = int(locs[1][1] - locs[1][0]) + + rmin, rmax = depth_h, depth_h + core_h + cmin, cmax = depth_w, depth_w + core_w + + out = np.zeros((block_height, block_width, block_channels), dtype=np.float32) + + for ch in range(block_channels): + img = np.asarray(block[..., ch]) # NumPy 2D view + + coords = peak_local_max( + img, + min_distance=min_distance, + threshold_abs=threshold_abs, + threshold_rel=threshold_rel, + exclude_border=False, + ) + + for r, c in coords: + if (rmin <= r < rmax) and (cmin <= c < cmax): + out[r, c, ch] = 1.0 + + return out diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 2509b51a3..ee5d09122 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -3,6 +3,7 @@ from . import ( deep_feature_extractor, engine_abc, + nucleus_detector, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, @@ -11,6 +12,7 @@ __all__ = [ "deep_feature_extractor", "engine_abc", + "nucleus_detector", "nucleus_instance_segmentor", "patch_predictor", "semantic_segmentor", diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index d7ac9ddfc..4b4c712df 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -45,7 +45,8 @@ import torch import zarr from dask import compute -from dask.diagnostics import ProgressBar +from dask.diagnostics.progress import ProgressBar +from numcodecs import Pickle from torch import nn from typing_extensions import Unpack @@ -699,29 +700,11 @@ def save_predictions( keys_to_compute = [k for k in processed_predictions if k not in self.drop_keys] if output_type.lower() == "zarr": - if is_zarr(save_path): - zarr_group = zarr.open(save_path, mode="r") - keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] - write_tasks = [] - for key in keys_to_compute: - dask_array = processed_predictions[key].rechunk("auto") - task = dask_array.to_zarr( - url=save_path, - component=key, - compute=False, - ) - write_tasks.append(task) - msg = f"Saving output to {save_path}." - logger.info(msg=msg) - with ProgressBar(): - compute(*write_tasks) - - zarr_group = zarr.open(save_path, mode="r+") - for key in self.drop_keys: - if key in zarr_group: - del zarr_group[key] - - return save_path + return self.save_predictions_as_zarr( + processed_predictions=processed_predictions, + save_path=save_path, + keys_to_compute=keys_to_compute, + ) values_to_compute = [processed_predictions[k] for k in keys_to_compute] @@ -754,6 +737,68 @@ def save_predictions( msg = f"Unsupported output type: {output_type}" raise TypeError(msg) + def save_predictions_as_zarr( + self: EngineABC, + processed_predictions: dict, + save_path: Path, + keys_to_compute: list, + ) -> Path: + """Save model predictions as a zarr file. + + This method saves the processed predictions to a zarr file at the specified + path. + + Args: + processed_predictions (dict): + Dictionary containing processed model predictions. + save_path (Path): + Path to save the zarr file. + keys_to_compute (list): + List of keys in processed_predictions to save. 
+ + Returns: + save_path (Path): + Path to the saved zarr file. + + """ + if is_zarr(save_path): + zarr_group = zarr.open(save_path, mode="r") + keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] + write_tasks = [] + for key in keys_to_compute: + dask_output = processed_predictions[key] + if isinstance(dask_output, da.Array): + dask_output = dask_output.rechunk("auto") + task = dask_output.to_zarr( + url=save_path, component=key, compute=False, object_codec=None + ) + write_tasks.append(task) + + if isinstance(dask_output, list) and all( + isinstance(dask_array, da.Array) for dask_array in dask_output + ): + for i, dask_array in enumerate(dask_output): + object_codec = Pickle() if dask_array.dtype == "object" else None + task = dask_array.to_zarr( + url=save_path, + component=f"{key}/{i}", + compute=False, + object_codec=object_codec, + ) + write_tasks.append(task) + + msg = f"Saving output to {save_path}." + logger.info(msg=msg) + with ProgressBar(): + compute(*write_tasks) + + zarr_group = zarr.open(save_path, mode="r+") + for key in self.drop_keys: + if key in zarr_group: + del zarr_group[key] + + return save_path + def infer_wsi( self: EngineABC, dataloader: DataLoader, diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py new file mode 100644 index 000000000..e23297c6a --- /dev/null +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -0,0 +1,1143 @@ +"""Nucleus Detection Engine for Digital Pathology (WSIs and patches). + +This module implements the `NucleusDetector` class which extends +`SemanticSegmentor` to perform instance-level nucleus detection on +histology images. It supports patch-mode and whole slide image (WSI) +workflows using TIAToolbox or custom PyTorch models, and provides +utilities for parallel post-processing (centroid extraction, thresholding), +merging detections across patches, and exporting results in multiple +formats (in-memory dict, Zarr, AnnotationStore). + +Classes +------- +NucleusDetectorRunParams + TypedDict specifying runtime configuration keys for detection. +NucleusDetector + Core engine for nucleus detection on image patches or WSIs. + +Examples: +-------- +>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector +>>> detector = NucleusDetector(model="mapde-conic") +>>> # WSI workflow: save to AnnotationStore (.db) +>>> out = detector.run( +... images=[pathlib.Path("example_wsi.tiff")], +... patch_mode=False, +... device="cuda", +... save_dir=pathlib.Path("output_directory/"), +... overwrite=True, +... output_type="annotationstore", +... class_dict={0: "nucleus"}, +... auto_get_mask=True, +... memory_threshold=80, +... ) +>>> # Patch workflow: return in-memory detections +>>> patches = [np.ndarray, np.ndarray] # NHWC +>>> out = detector.run(patches, patch_mode=True, output_type="dict") + +Notes: +----- +- Outputs can be returned as Python dictionaries, saved as Zarr groups, + or converted to AnnotationStore (.db). +- Post-processing uses tile rechunking and halo padding to facilitate + centroid extraction near chunk boundaries. 
+ +""" + +from __future__ import annotations + +import shutil +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np +from dask import compute +from dask.diagnostics.progress import ProgressBar +from shapely.geometry import Point + +from tiatoolbox import logger +from tiatoolbox.annotation.storage import Annotation, SQLiteStore +from tiatoolbox.models.engine.semantic_segmentor import ( + SemanticSegmentor, + SemanticSegmentorRunParams, +) +from tiatoolbox.utils.misc import get_tqdm + +if TYPE_CHECKING: # pragma: no cover + import os + from typing import Unpack + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.type_hints import IntPair, Resolution, Units + from tiatoolbox.wsicore import WSIReader + + from .io_config import IOSegmentorConfig + + +class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False): + """Runtime parameters for configuring the `NucleusDetector.run()` method. + + This class extends `SemanticSegmentorRunParams` (and transitively + `PredictorRunParams` → `EngineABCRunParams`) with additional options + specific to nucleus detection workflows. + + Attributes: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during WSI processing. + batch_size (int): + Number of image patches to feed to the model in a forward pass. + class_dict (dict): + Optional dictionary mapping numeric class IDs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + num_workers (int): + Number of workers used in the DataLoader. + output_file (str): + Output file name for saving results (e.g., ".zarr" or ".db"). + output_resolutions (Resolution): + Resolution used for writing output predictions/coordinates. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + min_distance (int): + Minimum separation between nuclei (in pixels) used during + centroid extraction/post-processing. + threshold_abs (float): + Absolute detection threshold applied to model outputs. + threshold_rel (float): + Relative detection threshold (e.g., with respect to local maxima). + postproc_tile_shape (tuple[int, int]): + Tile shape (height, width) used during post-processing + (in pixels) to control rechunking behavior. + return_labels (bool): + Whether to return labels with predictions. + return_probabilities (bool): + Whether to include per-class probabilities in the output. + scale_factor (tuple[float, float]): + Scale factor for converting coordinates to baseline resolution. + Typically, `model_mpp / slide_mpp`. + stride_shape (tuple[int, int]): + Stride used during WSI processing. Defaults to `patch_input_shape`. + verbose (bool): + Whether to enable verbose logging. + + """ + + min_distance: int + threshold_abs: float + threshold_rel: float + postproc_tile_shape: IntPair + + +class NucleusDetector(SemanticSegmentor): + r"""Nucleus detection engine for digital histology images. + + This class extends :class:`SemanticSegmentor` to support instance-level + nucleus detection using pretrained or custom models from TIAToolbox. 
+ It operates in both patch-level and whole slide image (WSI) modes and + provides utilities for post-processing (e.g., centroid extraction, + thresholding, tile-overlap handling), merging predictions, and saving + results in multiple output formats. Supported TIAToolbox models include + nucleus-detection architectures such as ``mapde-conic`` and + ``mapde-crchisto``. For the full list of pretrained models, refer to the + model zoo documentation: + https://tia-toolbox.readthedocs.io/en/latest/pretrained.html + + The class integrates seamlessly with the TIAToolbox engine interface, + inheriting the data loading, inference orchestration, memory-aware + chunking, and output-saving conventions of :class:`SemanticSegmentor`, + while overriding only the nucleus-specific post-processing and export + routines. + + Args: + model (str or nn.Module): + Defined PyTorch model or name of the existing models support by + tiatoolbox for processing the data e.g., mapde-conic, mapde-crchisto. + For a full list of pretrained models, please refer to the `docs + `. + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `weights` argument. Argument is case insensitive. + batch_size (int): + Number of image patches processed per forward pass. + Default is ``8``. + num_workers (int): + Number of workers for ``torch.utils.data.DataLoader``. + Default is ``0``. + weights (str or pathlib.Path or None): + Optional path to pretrained weights. If ``None`` and ``model`` is + a string, default pretrained weights for that model will be used. + If ``model`` is an ``nn.Module``, weights are loaded only if + provided. + device (str): + Device on which the model will run (e.g., ``"cpu"``, ``"cuda"``). + Default is ``"cpu"``. + verbose (bool): + Whether to output logging information. Default is ``True``. + + Attributes: + images (list[str or Path] or np.ndarray): + Input images supplied to the engine, either as WSI paths or + NHWC-formatted patches. + masks (list[str or Path] or np.ndarray): + Optional tissue masks for WSI processing. Only used when + ``patch_mode=False``. + patch_mode (bool): + Whether input is treated as image patches (``True``) or as WSIs + (``False``). + model (ModelABC): + Loaded PyTorch model. Can be a pretrained TIAToolbox model or a + custom user-provided model. + ioconfig (ModelIOConfigABC): + IO configuration specifying patch extraction shape, stride, and + resolution settings for inference. + return_labels (bool): + Whether to include labels in the output, if provided. + input_resolutions (list[dict]): + Resolution settings for model input heads. Supported units are + ``"level"``, ``"power"``, and ``"mpp"``. + patch_input_shape (tuple[int, int]): + Height and width of input patches read from slides, expressed in + read resolution space. + stride_shape (tuple[int, int]): + Stride used during patch extraction. Defaults to + ``patch_input_shape``. + drop_keys (list): + Keys to exclude from model output when saving results. + output_type (str): + Output format (``"dict"``, ``"zarr"``, or ``"annotationstore"``). + + Examples: + >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector + >>> model_name = "mapde-conic" + >>> detector = NucleusDetector(model=model_name, batch_size=16, num_workers=8) + >>> detector.run( + ... images=[pathlib.Path("example_wsi.tiff")], + ... patch_mode=False, + ... device="cuda", + ... save_dir=pathlib.Path("output_directory/"), + ... overwrite=True, + ... 
output_type="annotationstore", + ... class_dict={0: "nucleus"}, + ... auto_get_mask=True, + ... memory_threshold=80, + ... ) + + """ + + def __init__( + self: NucleusDetector, + model: str | ModelABC, + batch_size: int = 8, + num_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = True, + ) -> None: + """Initialize :class:`NucleusDetector`. + + This constructor follows the standard TIAToolbox engine initialization + workflow. A model may be provided either as a string referring to a + pretrained TIAToolbox architecture or as a custom ``torch.nn.Module``. + When ``model`` is a string, the corresponding pretrained weights are + automatically downloaded unless explicitly overridden via ``weights``. + + Args: + model (str or ModelABC): + A PyTorch model instance or the name of a pretrained TIAToolbox + model. If a string is provided, default pretrained weights are + loaded unless ``weights`` is supplied to override them. + + batch_size (int): + Number of image patches processed per forward pass. + Default is ``8``. + + num_workers (int): + Number of workers used for ``torch.utils.data.DataLoader``. + Default is ``0``. + + weights (str or Path or None): + Path to model weights. If ``None`` and ``model`` is a string, + the default pretrained weights for that model will be used. + If ``model`` is a ``nn.Module``, weights are loaded only when + specified here. + + device (str): + Device on which the model will run (e.g., ``"cpu"``, ``"cuda"``). + Default is ``"cpu"``. + + verbose (bool): + Whether to enable verbose logging during initialization and + inference. Default is ``True``. + + """ + super().__init__( + model=model, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, + verbose=verbose, + ) + + def post_process_patches( + self: NucleusDetector, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[NucleusDetectorRunParams], + ) -> dict: + """Post-process patch-level detection outputs. + + Applies the model's post-processing function (e.g., centroid extraction and + thresholding) to each patch's probability map, yielding per-patch detection + arrays suitable for saving or further merging. + + Args: + raw_predictions (da.Array): + Patch predictions of shape ``(B, H, W, C)``, where ``B`` is the number + of patches (probabilities/logits). + prediction_shape (tuple[int, ...]): + Expected prediction shape. + prediction_dtype (type): + Expected prediction dtype. + **kwargs (NucleusDetectorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + min_distance (int): + Minimum separation between nuclei (in pixels) used during + centroid extraction/post-processing. + threshold_abs (float): + Absolute detection threshold applied to model outputs. + threshold_rel (float): + Relative detection threshold + (e.g., with respect to local maxima). + + + Returns: + dict[str, list[da.Array]]: + A dictionary of lists (one list per patch), with keys: + - ``"x"`` (list[dask array]): + 1-D object dask arrays of x coordinates (``np.uint32``). + - ``"y"`` (list[dask array]): + 1-D object dask arrays of y coordinates (``np.uint32``). + - ``"classes"`` (list[dask array]): + 1-D object dask arrays of class IDs (``np.uint32``). + - ``"probabilities"`` (list[dask array]): + 1-D object dask arrays of detection scores (``np.float32``). + + Notes: + - If thresholds are not provided via ``kwargs``, model defaults are used. 
+ + """ + logger.info("Post processing patch predictions in NucleusDetector") + _ = prediction_shape + _ = prediction_dtype + + # If these are not provided, defaults from model will be used in postproc + min_distance = kwargs.get("min_distance") + threshold_abs = kwargs.get("threshold_abs") + threshold_rel = kwargs.get("threshold_rel") + + # Lists to hold per-patch detection arrays + xs = [] + ys = [] + classes = [] + probs = [] + + # Process each patch's predictions + for i in range(raw_predictions.shape[0]): + probs_prediction_patch = raw_predictions[i].compute() + centroids_map_patch = self.model.postproc( + probs_prediction_patch, + min_distance=min_distance, + threshold_abs=threshold_abs, + threshold_rel=threshold_rel, + ) + centroids_map_patch = da.from_array(centroids_map_patch, chunks="auto") + xs_patch, ys_patch, classes_patch, probs_patch = ( + self._centroid_maps_to_detection_arrays(centroids_map_patch).values() + ) + xs.append(xs_patch) + ys.append(ys_patch) + classes.append(classes_patch) + probs.append(probs_patch) + + return {"x": xs, "y": ys, "classes": classes, "probabilities": probs} + + def post_process_wsi( + self: NucleusDetector, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[NucleusDetectorRunParams], + ) -> dict[str, da.Array]: + """Post-process WSI-level nucleus detection outputs. + + Processes the full-slide prediction map using Dask's block-wise operations + to extract nuclei centroids across the entire WSI. The prediction map is + first re-chunked to the model's preferred post-processing tile shape, and + `dask.map_overlap` with halo padding is used to facilitate centroid + extraction on large prediction maps. The resulting centroid maps are + computed and saved to Zarr storage for memory-efficient processing, then + converted into detection arrays (x, y, classes, probabilities) through + sequential block processing. + + Args: + raw_predictions (da.Array): + WSI prediction map of shape ``(H, W, C)`` containing + per-class probabilities or logits. + prediction_shape (tuple[int, ...]): + Expected prediction shape. + prediction_dtype (type): + Expected prediction dtype. + **kwargs (NucleusDetectorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + min_distance (int): + Minimum distance separating two nuclei (in pixels). + threshold_abs (float): + Absolute detection threshold applied to model outputs. + threshold_rel (float): + Relative detection threshold + (e.g., with respect to local maxima). + postproc_tile_shape (tuple[int, int]): + Tile shape (height, width) for post-processing rechunking. + cache_dir (str or os.PathLike): + Directory for caching intermediate centroid maps as Zarr. + Defaults to './tmp/'. + + Returns: + dict[str, da.Array]: + A dictionary mapping detection fields to 1-D Dask arrays: + - ``"x"``: x coordinates of detected nuclei (``np.uint32``). + - ``"y"``: y coordinates of detected nuclei (``np.uint32``). + - ``"classes"``: class IDs (``np.uint32``). + - ``"probabilities"``: detection scores (``np.float32``). + + Notes: + - Halo padding ensures that nuclei crossing tile/chunk boundaries + are not fragmented or duplicated. + - If thresholds are not explicitly provided, model defaults are used. + - Centroid maps are computed and saved to Zarr storage to avoid + out-of-memory errors on large WSIs. + - The Zarr-backed centroid maps are then processed block-by-block + to extract detections incrementally. 
+ + """ + _ = prediction_shape + + logger.info("Post processing WSI predictions in NucleusDetector") + + # If these are not provided, defaults from model will be used in postproc + threshold_abs = kwargs.get("threshold_abs") + threshold_rel = kwargs.get("threshold_rel") + + # min_distance and postproc_tile_shape cannot be None here + min_distance = kwargs.get("min_distance") + if min_distance is None: + min_distance = self.model.min_distance + postproc_tile_shape = kwargs.get("postproc_tile_shape") + if postproc_tile_shape is None: + postproc_tile_shape = self.model.postproc_tile_shape + + # Add halo (overlap) around each block for post-processing + depth_h = min_distance + depth_w = min_distance + depth = {0: depth_h, 1: depth_w, 2: 0} + + # Re-chunk to post-processing tile shape for more efficient processing + rechunked_prediction_map = raw_predictions.rechunk( + (postproc_tile_shape[0], postproc_tile_shape[1], -1) + ) + + centroid_maps = da.map_overlap( + self.model.postproc, + rechunked_prediction_map, + min_distance=min_distance, + threshold_abs=threshold_abs, + threshold_rel=threshold_rel, + depth=depth, + boundary=0, + dtype=prediction_dtype, + block_info=True, + depth_h=depth_h, + depth_w=depth_w, + ) + + logger.info("Computing and saving centroid maps to temporary zarr file.") + temp_zarr_file = tempfile.TemporaryDirectory( + prefix="tiatoolbox_nucleus_detector_", suffix=".zarr" + ) + logger.info("Temporary zarr file created at: %s", temp_zarr_file.name) + task = centroid_maps.to_zarr( + url=temp_zarr_file.name, compute=False, object_codec=None + ) + with ProgressBar(): + compute(task) + + centroid_maps = da.from_zarr(temp_zarr_file.name) + + return self._centroid_maps_to_detection_arrays(centroid_maps) + + def save_predictions( + self: NucleusDetector, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[NucleusDetectorRunParams], + ) -> dict | AnnotationStore | Path | list[Path]: + """Save nucleus detections to disk or return them in memory. + + Saves post-processed detection outputs in one of the supported formats. + If ``patch_mode=True``, predictions are saved per image. If + ``patch_mode=False``, detections are merged and saved as a single output. + + Args: + processed_predictions (dict): + Dictionary containing processed detection results. Expected to include + a ``"predictions"`` key with detection arrays. The internal structure + follows TIAToolbox conventions and may differ slightly between patch + and WSI modes: + - Patch mode: + - ``"x"`` (list[da.Array]): + per-patch x coordinates (np.uint32). + - ``"y"`` (list[da.Array]): + per-patch y coordinates (np.uint32). + - ``"classes"`` (list[da.Array]): + per-patch class IDs (np.uint32). + - ``"probabilities"`` (list[da.Array]): + per-patch detection scores (np.float32). + - WSI mode: + - ``"x"`` (da.Array): + x coordinates (np.uint32). + - ``"y"`` (da.Array): + y coordinates (np.uint32). + - ``"classes"`` (da.Array): + class IDs (np.uint32). + - ``"probabilities"`` (da.Array): + detection scores (np.float32). + + output_type (str): + Desired output format: ``"dict"``, ``"zarr"``, or ``"annotationstore"``. + + save_path (Path | None): + Path at which to save the output file(s). Required for file outputs + (e.g., Zarr or SQLite .db). If ``None`` and ``output_type="dict"``, + results are returned in memory. + + **kwargs (NucleusDetectorRunParams): + Additional runtime parameters to configure segmentation. 
+ + Optional Keys: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches to feed to the model in a forward pass. + class_dict (dict): + Optional dictionary mapping classification outputs to + class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (in percentage) to + trigger caching behavior. + num_workers (int): + Number of workers used in DataLoader. + output_file (str): + Output file name for saving results (e.g., .zarr or .db). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + min_distance (int): + Minimum distance separating two nuclei (in pixels). + postproc_tile_shape (tuple[int, int]): + Tile shape (height, width) for post-processing (in pixels). + return_labels (bool): + Whether to return labels with predictions. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + Scale factor for converting annotations to baseline resolution. + Typically model_mpp / slide_mpp. + stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to patch_input_shape. + verbose (bool): + Whether to output logging information. + + Returns: + dict | AnnotationStore | Path | list[Path]: + - If ``output_type="dict"``: + returns a Python dictionary of predictions. + - If ``output_type="zarr"``: + returns the path to the saved ``.zarr`` group. + - If ``output_type="annotationstore"``: + returns an AnnotationStore handle or the path(s) to saved + ``.db`` file(s). In patch mode, a list of per-image paths + may be returned. + + Notes: + - For non-AnnotationStore outputs, this method delegates to the + base engine's saving function to preserve consistency across + TIAToolbox engines. + + """ + if output_type.lower() != "annotationstore": + return super().save_predictions( + processed_predictions["predictions"], + output_type, + save_path=save_path, + **kwargs, + ) + + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + if class_dict is None: + class_dict = self.model.output_class_dict + + return self._save_predictions_annotation_store( + processed_predictions, + save_path=save_path, + scale_factor=scale_factor, + class_dict=class_dict, + ) + + def _save_predictions_annotation_store( + self: NucleusDetector, + processed_predictions: dict, + save_path: Path | None = None, + scale_factor: tuple[float, float] = (1.0, 1.0), + class_dict: dict | None = None, + ) -> AnnotationStore | Path | list[Path]: + """Save nucleus detections to an AnnotationStore (.db). + + Converts the processed detection arrays into per-instance `Annotation` + records, applies coordinate scaling and optional class-ID remapping, + and writes the results into an SQLite-backed AnnotationStore. In patch + mode, detections are written to separate `.db` files per input image; + in WSI mode, all detections are merged and written to a single store. + + Args: + processed_predictions (dict): + Dictionary containing the computed detection outputs. 
Expected to + include a top-level key ``"predictions"`` with fields: + - ``"x"`` (da.Array): + dask array of x coordinates (``np.uint32``) + - ``"y"`` (da.Array): + dask array of y coordinates (``np.uint32``) + - ``"classes"`` (da.Array): + dask array of class IDs (``np.uint32``) + - ``"probabilities"`` (da.Array): + dask array of detection scores (``np.float32``) + + save_path (Path or None): + Output path for saving the AnnotationStore. If ``None``, an in-memory + store is returned. When patch mode is active, this path serves as the + directory for producing one `.db` file per patch input. + + scale_factor (tuple[float, float], optional): + Scaling factors applied to x and y coordinates prior to writing. + Typically corresponds to ``model_mpp / slide_mpp``. + Defaults to ``(1.0, 1.0)``. + + class_dict (dict or None): + Optional mapping from original class IDs to class names or remapped IDs. + If ``None``, an identity mapping based on present classes is used. + + Returns: + AnnotationStore or Path or list[Path]: + - For WSI mode: a single AnnotationStore handle or the path to the saved + `.db` file. + - For patch mode: a list of paths, one per saved patch-level + AnnotationStore. + + Notes: + - This method centralizes the translation of detection arrays into + `Annotation` objects and abstracts batching logic via + ``_write_detection_arrays_to_store``. + + """ + logger.info("Saving predictions as AnnotationStore.") + if self.patch_mode: + save_paths = [] + detections = processed_predictions["predictions"] + + num_patches = len(detections["x"]) + for i in range(num_patches): + if isinstance(self.images[i], Path): + output_path = save_path.parent / (self.images[i].stem + ".db") + else: + output_path = save_path.parent / (str(i) + ".db") + + detection_arrays = { + "x": detections["x"][i], + "y": detections["y"][i], + "classes": detections["classes"][i], + "probabilities": detections["probabilities"][i], + } + + out_file = self.save_detection_arrays_to_store( + detection_arrays=detection_arrays, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=output_path, + ) + + save_paths.append(out_file) + return save_paths + predictions = processed_predictions["predictions"] + return self.save_detection_arrays_to_store( + detection_arrays=predictions, + scale_factor=scale_factor, + save_path=save_path, + class_dict=class_dict, + ) + + @staticmethod + def _centroid_maps_to_detection_arrays( + detection_maps: da.Array, + ) -> dict[str, da.Array]: + """Convert centroid maps into 1-D detection arrays. + + This helper function extracts non-zero centroid predictions from a + already computed Dask array of centroid maps and flattens them into + coordinate, class, and probability arrays suitable for saving or + further processing. The function processes the centroid maps block + by block to minimize memory usage, reading each block from disk + and extracting detections incrementally. + + Args: + detection_maps (da.Array): + A Dask array of shape ``(H, W, C)`` representing centroid + probability maps, where non-zero values correspond to nucleus + detections. Each non-zero entry encodes both the class channel + and its associated probability. This array is expected to be + already computed. + + Returns: + dict[str, da.Array]: + A dictionary containing four 1-D Dask arrays: + - ``"x"``: + x coordinates of detected nuclei (``np.uint32``). + - ``"y"``: + y coordinates of detected nuclei (``np.uint32``). + - ``"classes"``: + class IDs for each detection (``np.uint32``). 
+ - ``"probabilities"``: + detection probabilities (``np.float32``). + + Notes: + - The centroid maps are expected to be pre-computed. + - Blocks are processed sequentially to avoid loading the entire + centroid map into memory at once. + - Global coordinates are computed by adding block offsets to local + coordinates within each block. + - This method is used by both patch-level and WSI-level + post-processing routines to unify detection formatting. + + + """ + logger.info("Extracting detections from centroid maps block by block...") + + # Get chunk information + num_blocks_h = detection_maps.numblocks[0] + num_blocks_w = detection_maps.numblocks[1] + + # Lists to collect detections from each block + ys_list = [] + xs_list = [] + classes_list = [] + probs_list = [] + + tqdm = get_tqdm() + for i in tqdm(range(num_blocks_h), desc="Processing detection blocks"): + for j in range(num_blocks_w): + # Get block offsets + y_offset = sum(detection_maps.chunks[0][:i]) if i > 0 else 0 + x_offset = sum(detection_maps.chunks[1][:j]) if j > 0 else 0 + + # Read this block from Zarr (already computed, so this is just I/O) + block = np.array(detection_maps.blocks[i, j]) + + # Extract nonzero detections + ys, xs, classes = np.nonzero(block) + probs = block[ys, xs, classes] + + # Adjust to global coordinates + ys = ys + y_offset + xs = xs + x_offset + + # Append to lists if we have detections + if len(ys) > 0: + ys_list.append(ys.astype(np.uint32)) + xs_list.append(xs.astype(np.uint32)) + classes_list.append(classes.astype(np.uint32)) + probs_list.append(probs.astype(np.float32)) + + # Concatenate all block results + if ys_list: + ys = np.concatenate(ys_list) + xs = np.concatenate(xs_list) + classes = np.concatenate(classes_list) + probs = np.concatenate(probs_list) + else: + ys = np.array([], dtype=np.uint32) + xs = np.array([], dtype=np.uint32) + classes = np.array([], dtype=np.uint32) + probs = np.array([], dtype=np.float32) + + return { + "y": da.from_array(ys, chunks="auto"), + "x": da.from_array(xs, chunks="auto"), + "classes": da.from_array(classes, chunks="auto"), + "probabilities": da.from_array(probs, chunks="auto"), + } + + @staticmethod + def _write_detection_arrays_to_store( + detection_arrays: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + store: SQLiteStore, + scale_factor: tuple[float, float], + class_dict: dict[int, str | int] | None, + batch_size: int = 5000, + ) -> int: + """Write detection arrays to an AnnotationStore in batches. + + Converts coordinate, class, and probability arrays into `Annotation` + objects and appends them to an SQLite-backed store in configurable + batch sizes. Coordinates are scaled to baseline slide resolution using + the provided `scale_factor`, and optional class-ID remapping is applied + via `class_dict`. + + Args: + detection_arrays (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): + Tuple of arrays in the order: + `(x_coords, y_coords, class_ids, probabilities)`. + Each element must be a 1-D NumPy array of equal length. + store (SQLiteStore): + Target `AnnotationStore` instance to receive the detections. + scale_factor (tuple[float, float]): + Factors applied to `(x, y)` coordinates prior to writing, + typically `(model_mpp / slide_mpp)`. The scaled coordinates are + rounded to `np.uint32`. + class_dict (dict[int, str | int] | None): + Optional mapping from original class IDs to names or remapped IDs. + If `None`, an identity mapping is used for the set of present classes. + batch_size (int): + Number of records to write per batch. 
Default is `5000`. + + Returns: + int: + Total number of detection records written to the store. + + Notes: + - Coordinates are scaled and rounded to integers to ensure consistent + geometry creation for `Annotation` points. + - Class mapping is applied per-record; unmapped IDs fall back to their + original values. + - Writing in batches reduces memory pressure and improves throughput + on large number of detections. + + """ + xs, ys, classes, probs = detection_arrays + n = len(xs) + if n == 0: + return 0 # nothing to write + + # scale coordinates + xs = np.rint(xs * scale_factor[0]).astype(np.uint32, copy=False) + ys = np.rint(ys * scale_factor[1]).astype(np.uint32, copy=False) + + # class mapping + if class_dict is None: + # identity over actually-present types + uniq = np.unique(classes) + class_dict = {int(k): int(k) for k in uniq} + labels = np.array( + [class_dict.get(int(k), int(k)) for k in classes], dtype=object + ) + + def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]: + """Create Shapely Point geometries from coordinate arrays in batches.""" + return [ + Point(int(xx), int(yy)) + for xx, yy in zip(xs_batch, ys_batch, strict=True) + ] + + tqdm = get_tqdm() + tqdm_loop = tqdm(range(0, n, batch_size), desc="Writing detections to store") + written = 0 + for i in tqdm_loop: + j = min(i + batch_size, n) + pts = make_points(xs[i:j], ys[i:j]) + + anns = [ + Annotation( + geometry=pt, properties={"type": lbl, "probability": float(pp)} + ) + for pt, lbl, pp in zip(pts, labels[i:j], probs[i:j], strict=True) + ] + store.append_many(anns) + written += j - i + return written + + @staticmethod + def save_detection_arrays_to_store( + detection_arrays: dict[str, da.Array], + scale_factor: tuple[float, float] = (1.0, 1.0), + class_dict: dict | None = None, + save_path: Path | None = None, + batch_size: int = 5000, + ) -> Path | SQLiteStore: + """Write nucleus detection arrays to an SQLite-backed AnnotationStore. + + Converts the detection arrays into NumPy form, applies coordinate scaling + and optional class-ID remapping, and writes the results into an in-memory + SQLiteStore. If `save_path` is provided, the store is committed and saved + to disk as a `.db` file. This method provides a unified interface for + converting Dask-based detection outputs into persistent annotation storage. + + Args: + detection_arrays (dict[str, da.Array]): + A dictionary containing the detection fields: + - ``"x"``: dask array of x coordinates (``np.uint32``). + - ``"y"``: dask array of y coordinates (``np.uint32``). + - ``"classes"``: dask array of class IDs (``np.uint32``). + - ``"probabilities"``: dask array of detection scores (``np.float32``). + + scale_factor (tuple[float, float], optional): + Multiplicative factors applied to the x and y coordinates before + saving. The scaled coordinates are rounded to integer pixel + locations. Defaults to ``(1.0, 1.0)``. + + class_dict (dict or None): + Optional mapping of class IDs to class names or remapped IDs. + If ``None``, an identity mapping is used based on the detected + class IDs. + + save_path (Path or None): + Destination path for saving the `.db` file. If ``None``, the + resulting SQLiteStore is returned in memory. If provided, the + parent directory is created if needed, and the final store is + written as ``save_path.with_suffix(".db")``. + + batch_size (int): + Number of detection records to write per batch. Defaults to ``5000``. + + Returns: + Path or SQLiteStore: + - If `save_path` is provided: the path to the saved `.db` file. 
+ - If `save_path` is ``None``: an in-memory `SQLiteStore` containing + all detections. + + Notes: + - The heavy lifting is delegated to + :meth:`NucleusDetector._write_detection_arrays_to_store`, + which performs coordinate scaling, class mapping, and batch writing. + + """ + xs = detection_arrays["x"] + ys = detection_arrays["y"] + classes = detection_arrays["classes"] + probs = detection_arrays["probabilities"] + + xs = np.atleast_1d(np.asarray(xs)) + ys = np.atleast_1d(np.asarray(ys)) + classes = np.atleast_1d(np.asarray(classes)) + probs = np.atleast_1d(np.asarray(probs)) + + if not len(xs) == len(ys) == len(classes) == len(probs): + msg = "Detection record lengths are misaligned." + raise ValueError(msg) + + store = SQLiteStore() + total_written = NucleusDetector._write_detection_arrays_to_store( + (xs, ys, classes, probs), + store, + scale_factor, + class_dict, + batch_size, + ) + logger.info("Total detections written to store: %s", total_written) + + if save_path: + save_path.parent.absolute().mkdir(parents=True, exist_ok=True) + save_path = save_path.parent.absolute() / (save_path.stem + ".db") + store.commit() + store.dump(save_path) + return save_path + + return store + + def run( + self: NucleusDetector, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + *, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + input_resolutions: list[dict[Units, Resolution]] | None = None, + patch_input_shape: IntPair | None = None, + ioconfig: IOSegmentorConfig | None = None, + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[NucleusDetectorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the nucleus detection engine on input images. + + This method orchestrates the full inference pipeline, including preprocessing, + model inference, post-processing, and saving results. It supports both + patch-level and whole slide image (WSI) modes. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. Can be a list of file paths, WSIReader objects, + or a NumPy array of image patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. Only used when `patch_mode` is False. + input_resolutions (list[dict[Units, Resolution]] | None): + Resolution settings for input heads. Supported units are `level`, + `power`, and `mpp`. Keys should be "units" and "resolution", e.g., + [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for + details. + patch_input_shape (IntPair | None): + Shape of input patches (height, width), requested at read + resolution. Must be positive. + ioconfig (IOSegmentorConfig | None): + IO configuration for patch extraction and resolution. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). Default + is True. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + overwrite (bool): + Whether to overwrite existing output files. Default is False. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". Default + is "dict". + **kwargs (NucleusDetectorRunParams): + Additional runtime parameters to configure segmentation. + + Optional Keys: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches to feed to the model in a forward pass. 
+ class_dict (dict): + Optional dictionary mapping classification outputs to + class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + labels (list): + Optional labels for input images. Only a single label per image + is supported. + memory_threshold (int): + Memory usage threshold (in percentage) to + trigger caching behavior. + num_workers (int): + Number of workers used in DataLoader. + output_file (str): + Output file name for saving results (e.g., .zarr or .db). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + min_distance (int): + Minimum distance separating two nuclei (in pixels). + threshold_abs (float): + Absolute detection threshold applied to model outputs. + threshold_rel (float): + Relative detection threshold + (e.g., with respect to local maxima). + postproc_tile_shape (tuple[int, int]): + Tile shape (height, width) for post-processing (in pixels). + return_labels (bool): + Whether to return labels with predictions. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + Scale factor for converting annotations to baseline resolution. + Typically model_mpp / slide_mpp. + stride_shape (tuple[int, int]): + Stride used during WSI processing. + Defaults to patch_input_shape. + verbose (bool): + Whether to output logging information. + + Returns: + AnnotationStore | Path | str | dict | list[Path]: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI + to its output path. + + Examples: + >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector + >>> detector = NucleusDetector(model="mapde-conic") + >>> # WSI workflow: save to AnnotationStore (.db) + >>> out = detector.run( + ... images=[pathlib.Path("example_wsi.tiff")], + ... patch_mode=False, + ... device="cuda", + ... save_dir=pathlib.Path("output_directory/"), + ... overwrite=True, + ... output_type="annotationstore", + ... class_dict={0: "nucleus"}, + ... auto_get_mask=True, + ... memory_threshold=80, + ... 
)
+        >>> # Patch workflow: return in-memory detections
+        >>> patches = np.zeros((2, 252, 252, 3), dtype=np.uint8)  # NHWC
+        >>> out = detector.run(patches, patch_mode=True, output_type="dict")
+
+        """
+        output = super().run(
+            images=images,
+            masks=masks,
+            input_resolutions=input_resolutions,
+            patch_input_shape=patch_input_shape,
+            ioconfig=ioconfig,
+            patch_mode=patch_mode,
+            save_dir=save_dir,
+            overwrite=overwrite,
+            output_type=output_type,
+            **kwargs,
+        )
+
+        if not patch_mode:
+            # Clean up the temporary zarr directory after WSI processing.
+            # It should already have been deleted, but check anyway.
+            temp_dir = Path(tempfile.gettempdir())
+            if temp_dir.exists():
+                # Find directories starting with 'tiatoolbox_nucleus_detector_'
+                # and ending with '.zarr'.
+                for item in temp_dir.iterdir():
+                    if item.name.startswith(
+                        "tiatoolbox_nucleus_detector_"
+                    ) and item.name.endswith(".zarr"):
+                        shutil.rmtree(item)
+                        logger.info(
+                            "Temporary zarr directory %s has been removed.", item
+                        )
+
+        return output
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index e429940d7..ca0d6de5a 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -1369,7 +1369,7 @@ def prepare_full_batch(
     full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)}
     matches = [full_output_dict[tuple(row)] for row in batch_locs]
 
-    total_size = np.max(matches).astype(np.uint16) + 1
+    total_size = np.max(matches).astype(np.uint32) + 1
 
     # Initialize full output array
     full_batch_output = np.zeros(