Skip to content

Commit c99cd91

Browse files
author
Donglai Wei
committed
load external mednext model
1 parent 066f1c2 commit c99cd91

File tree

10 files changed

+583
-77
lines changed

10 files changed

+583
-77
lines changed

connectomics/config/hydra_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ class ModelConfig:
202202
None # None = single task (apply all losses to all channels)
203203
)
204204

205+
# External model weights loading
206+
# For loading pretrained weights from external checkpoints (e.g., BANIS, nnUNet)
207+
external_weights_path: Optional[str] = None # Path to external checkpoint file
208+
external_weights_key_prefix: str = "model." # Prefix to strip from state_dict keys
209+
205210

206211
# Label transformation configurations
207212
@dataclass
@@ -858,6 +863,7 @@ class SlidingWindowConfig:
858863
)
859864
padding_mode: str = "constant" # Padding mode at volume boundaries
860865
pad_size: Optional[List[int]] = None # Padding size for context (e.g., [16, 32, 32])
866+
save_channels: Optional[List[int]] = None # Channel indices to save (None = all channels)
861867

862868

863869
@dataclass

connectomics/data/io/io.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import h5py
1717
import numpy as np
1818
import imageio
19+
import nibabel as nib
1920

2021
# Avoid PIL "IOError: image file truncated"
2122
from PIL import ImageFile
@@ -263,7 +264,7 @@ def write_pickle_file(filename: str, data: object) -> None:
263264
def read_volume(
264265
filename: str, dataset: Optional[str] = None, drop_channel: bool = False
265266
) -> np.ndarray:
266-
"""Load volumetric data in HDF5, TIFF or PNG formats.
267+
"""Load volumetric data in HDF5, TIFF, PNG, or NIfTI formats.
267268
268269
Args:
269270
filename: Path to the volume file
@@ -276,7 +277,11 @@ def read_volume(
276277
Raises:
277278
ValueError: If file format is not recognized
278279
"""
279-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
280+
# Handle .nii.gz files specially
281+
if filename.endswith('.nii.gz'):
282+
image_suffix = 'nii.gz'
283+
else:
284+
image_suffix = filename[filename.rfind(".") + 1 :].lower()
280285

281286
if image_suffix in ["h5", "hdf5"]:
282287
data = read_hdf5(filename, dataset)
@@ -314,9 +319,23 @@ def read_volume(
314319
if data.ndim == 4:
315320
# Convert (D, H, W, C) to (C, D, H, W) order
316321
data = data.transpose(3, 0, 1, 2)
322+
elif image_suffix in ["nii", "nii.gz"]:
323+
# NIfTI format (.nii or .nii.gz)
324+
nii_img = nib.load(filename)
325+
data = np.asarray(nii_img.dataobj)
326+
# NIfTI is typically (X, Y, Z) or (X, Y, Z, C)
327+
# Convert to our (D, H, W) or (C, D, H, W) format
328+
# X=W (width), Y=H (height), Z=D (depth)
329+
if data.ndim == 3:
330+
# (X, Y, Z) -> (Z, Y, X) = (D, H, W)
331+
data = data.transpose(2, 1, 0)
332+
elif data.ndim == 4:
333+
# (X, Y, Z, C) -> (C, Z, Y, X) = (C, D, H, W)
334+
data = data.transpose(3, 2, 1, 0)
317335
else:
318336
raise ValueError(
319-
f"Unrecognizable file format for {filename}. " f"Expected: h5, hdf5, tif, tiff, or png"
337+
f"Unrecognizable file format for {filename}. "
338+
f"Expected: h5, hdf5, tif, tiff, png, nii, or nii.gz"
320339
)
321340

322341
# if data.ndim not in [3, 4]:
@@ -342,7 +361,7 @@ def save_volume(
342361
filename: Output filename or directory path
343362
volume: Volume data to save
344363
dataset: Dataset name for HDF5 format
345-
file_format: Output format ('h5' or 'png')
364+
file_format: Output format ('h5', 'png', 'nii', or 'nii.gz')
346365
347366
Raises:
348367
ValueError: If file format is not supported
@@ -351,8 +370,21 @@ def save_volume(
351370
write_hdf5(filename, volume, dataset=dataset)
352371
elif file_format == "png":
353372
save_images(filename, volume)
373+
elif file_format in ["nii", "nii.gz"]:
374+
# NIfTI format
375+
# Convert from our (D, H, W) or (C, D, H, W) to NIfTI (X, Y, Z) or (X, Y, Z, C)
376+
if volume.ndim == 3:
377+
# (D, H, W) -> (W, H, D) = (X, Y, Z)
378+
nii_data = volume.transpose(2, 1, 0)
379+
elif volume.ndim == 4:
380+
# (C, D, H, W) -> (W, H, D, C) = (X, Y, Z, C)
381+
nii_data = volume.transpose(3, 2, 1, 0)
382+
else:
383+
nii_data = volume
384+
nii_img = nib.Nifti1Image(nii_data, affine=np.eye(4))
385+
nib.save(nii_img, filename)
354386
else:
355-
raise ValueError(f"Unsupported format: {file_format}. " f"Supported formats: h5, png")
387+
raise ValueError(f"Unsupported format: {file_format}. " f"Supported formats: h5, png, nii, nii.gz")
356388

357389

358390
def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
@@ -382,7 +414,11 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
382414
if not os.path.exists(filename):
383415
raise FileNotFoundError(f"File not found: {filename}")
384416

385-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
417+
# Handle .nii.gz files specially
418+
if filename.endswith('.nii.gz'):
419+
image_suffix = 'nii.gz'
420+
else:
421+
image_suffix = filename[filename.rfind(".") + 1 :].lower()
386422

387423
if image_suffix in ["h5", "hdf5"]:
388424
# HDF5: Read shape from metadata (no data loading)
@@ -424,9 +460,24 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
424460
else:
425461
raise ValueError(f"Unsupported PNG dimensions: {first_image.ndim}D")
426462

463+
elif image_suffix in ["nii", "nii.gz"]:
464+
# NIfTI: Read shape from header (no data loading)
465+
nii_img = nib.load(filename)
466+
nii_shape = nii_img.header.get_data_shape()
467+
# Convert from NIfTI (X, Y, Z) or (X, Y, Z, C) to our (D, H, W) or (C, D, H, W)
468+
if len(nii_shape) == 3:
469+
# (X, Y, Z) -> (Z, Y, X) = (D, H, W)
470+
return (nii_shape[2], nii_shape[1], nii_shape[0])
471+
elif len(nii_shape) == 4:
472+
# (X, Y, Z, C) -> (C, Z, Y, X) = (C, D, H, W)
473+
return (nii_shape[3], nii_shape[2], nii_shape[1], nii_shape[0])
474+
else:
475+
return nii_shape
476+
427477
else:
428478
raise ValueError(
429-
f"Unrecognizable file format for {filename}. " f"Expected: h5, hdf5, tif, tiff, or png"
479+
f"Unrecognizable file format for {filename}. "
480+
f"Expected: h5, hdf5, tif, tiff, png, nii, or nii.gz"
430481
)
431482

432483

connectomics/decoding/segmentation.py

Lines changed: 122 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from __future__ import print_function, division
18-
from typing import Optional, Tuple
18+
from typing import Optional, Tuple, List
1919
import numpy as np
2020
import cc3d
2121
import fastremap
@@ -352,14 +352,17 @@ def decode_binary_contour_watershed(
352352
def decode_binary_contour_distance_watershed(
353353
predictions: np.ndarray,
354354
binary_threshold: Tuple[float, float] = (0.9, 0.85),
355-
contour_threshold: Tuple[float, float] = (0.8, 1.1),
355+
contour_threshold: Optional[Tuple[float, float]] = (0.8, 1.1),
356356
distance_threshold: Tuple[float, float] = (0.5, 0),
357357
min_instance_size: int = 128,
358358
remove_small_mode: str = "background",
359359
min_seed_size: int = 32,
360360
return_seed: bool = False,
361361
precomputed_seed: Optional[np.ndarray] = None,
362362
prediction_scale: int = 255,
363+
binary_channels: Optional[List[int]] = None,
364+
contour_channels: Optional[List[int]] = None,
365+
distance_channels: Optional[List[int]] = None,
363366
):
364367
r"""Convert binary foreground probability maps, instance contours and signed distance
365368
transform to instance masks via watershed segmentation algorithm.
@@ -369,48 +372,145 @@ def decode_binary_contour_distance_watershed(
369372
function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays.
370373
371374
Args:
372-
predictions (numpy.ndarray): foreground, contour, and distance probability of shape :math:`(3, Z, Y, X)`.
375+
predictions (numpy.ndarray): foreground, contour, and distance probability of shape :math:`(3, Z, Y, X)`
376+
or :math:`(2, Z, Y, X)` if contour is disabled. With position-based assignment, any extra leading channels are averaged into the binary map.
373377
binary_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for binary mask.
374378
The first value is used for seed generation, the second for foreground mask. Default: (0.9, 0.85)
375-
contour_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for instance contours.
376-
The first value is used for seed generation, the second for foreground mask. Default: (0.8, 1.1)
379+
contour_threshold (tuple or None): tuple of two floats (seed_threshold, foreground_threshold) for instance contours.
380+
The first value is used for seed generation, the second for foreground mask.
381+
Set to None to disable contour constraints (for BANIS-style binary+distance only). Default: (0.8, 1.1)
377382
distance_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for signed distance.
378383
The first value is used for seed generation, the second for foreground mask. Default: (0.5, 0)
379384
min_instance_size (int): minimum size threshold for instances to keep. Default: 128
380385
remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``. Default: ``'background'``
381386
min_seed_size (int): minimum size of seed objects. Default: 32
382387
return_seed (bool): whether to return the seed map. Default: False
383388
precomputed_seed (numpy.ndarray, optional): precomputed seed map. Default: None
384-
prediction_scale (int): scale of input predictions (255 for uint8 range). Default: 255
389+
prediction_scale (int): scale of input predictions (255 for uint8 range, 1 for 0-1 range). Default: 255
390+
binary_channels (list of int, optional): channel indices for binary mask. If multiple, they are averaged.
391+
Default: None (uses position-based assignment)
392+
contour_channels (list of int, optional): channel indices for contour. If multiple, they are averaged.
393+
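Default: None (uses position-based assignment). Only consulted when contour_threshold is not None and binary_channels or distance_channels is also given.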
394+
distance_channels (list of int, optional): channel indices for distance. If multiple, they are averaged.
395+
Default: None (uses position-based assignment)
385396
386397
Returns:
387398
numpy.ndarray or tuple: Instance segmentation mask, or (mask, seed) if return_seed=True.
388-
"""
389-
assert predictions.shape[0] == 3
390-
binary, contour, distance = predictions[0], predictions[1], predictions[2]
391399
400+
Example:
401+
>>> # Standard 3-channel (binary, contour, distance)
402+
>>> seg = decode_binary_contour_distance_watershed(predictions)
403+
404+
>>> # BANIS-style 2-channel (binary, distance) - no contour
405+
>>> seg = decode_binary_contour_distance_watershed(
406+
... predictions, # shape (2, Z, Y, X)
407+
... binary_threshold=(0.5, 0.5),
408+
... contour_threshold=None, # Disable contour
409+
... distance_threshold=(0.0, -1.0),
410+
... prediction_scale=1,
411+
... )
412+
413+
>>> # Explicit channel selection with averaging
414+
>>> seg = decode_binary_contour_distance_watershed(
415+
... predictions, # shape (3, Z, Y, X) with channels [aff_x, aff_y, SDT]
416+
... binary_channels=[0, 1], # Average channels 0 and 1 for binary
417+
... contour_channels=None, # No contour
418+
... distance_channels=[2], # Channel 2 for distance
419+
... contour_threshold=None,
420+
... prediction_scale=1,
421+
... )
422+
"""
423+
# Check if contour is disabled
424+
use_contour = contour_threshold is not None
425+
426+
# Extract channels using explicit selection or position-based fallback
427+
if binary_channels is not None or distance_channels is not None:
428+
# Explicit channel selection mode
429+
if binary_channels is not None:
430+
if len(binary_channels) > 1:
431+
binary = predictions[binary_channels].mean(axis=0)
432+
else:
433+
binary = predictions[binary_channels[0]]
434+
else:
435+
binary = predictions[0]
436+
437+
if distance_channels is not None:
438+
if len(distance_channels) > 1:
439+
distance = predictions[distance_channels].mean(axis=0)
440+
else:
441+
distance = predictions[distance_channels[0]]
442+
else:
443+
distance = predictions[-1]
444+
445+
if use_contour:
446+
if contour_channels is not None:
447+
if len(contour_channels) > 1:
448+
contour = predictions[contour_channels].mean(axis=0)
449+
else:
450+
contour = predictions[contour_channels[0]]
451+
else:
452+
# Default: assume contour is second-to-last if using contour
453+
contour = predictions[-2]
454+
else:
455+
contour = None
456+
else:
457+
# Position-based fallback (legacy behavior)
458+
if use_contour:
459+
assert predictions.shape[0] >= 3, f"Expected at least 3 channels (binary, contour, distance), got {predictions.shape[0]}"
460+
# If more than 3 channels, first N-2 channels are binary (average them)
461+
if predictions.shape[0] > 3:
462+
binary = predictions[:-2].mean(axis=0)
463+
contour, distance = predictions[-2], predictions[-1]
464+
else:
465+
binary, contour, distance = predictions[0], predictions[1], predictions[2]
466+
else:
467+
assert predictions.shape[0] >= 2, f"Expected at least 2 channels (binary, distance) when contour disabled, got {predictions.shape[0]}"
468+
# If more than 2 channels, first N-1 channels are binary (average them)
469+
if predictions.shape[0] > 2:
470+
binary = predictions[:-1].mean(axis=0)
471+
distance = predictions[-1]
472+
else:
473+
binary, distance = predictions[0], predictions[1]
474+
contour = None
475+
476+
# Convert thresholds based on prediction scale
392477
if prediction_scale == 255:
393478
distance = (distance / prediction_scale) * 2.0 - 1.0
394-
binary_threshold = binary_threshold * prediction_scale
395-
contour_threshold = contour_threshold * prediction_scale
396-
distance_threshold = distance_threshold * prediction_scale
479+
binary_threshold = (binary_threshold[0] * prediction_scale, binary_threshold[1] * prediction_scale)
480+
if use_contour:
481+
contour_threshold = (contour_threshold[0] * prediction_scale, contour_threshold[1] * prediction_scale)
482+
distance_threshold = (distance_threshold[0] * prediction_scale, distance_threshold[1] * prediction_scale)
397483

398484
if precomputed_seed is not None:
399485
seed = precomputed_seed
400486
else: # compute the instance seeds
401-
seed_map = (
402-
(binary > binary_threshold[0])
403-
* (contour < contour_threshold[0])
404-
* (distance > distance_threshold[0])
405-
)
487+
if use_contour:
488+
seed_map = (
489+
(binary > binary_threshold[0])
490+
* (contour < contour_threshold[0])
491+
* (distance > distance_threshold[0])
492+
)
493+
else:
494+
# No contour constraint - only binary and distance
495+
seed_map = (
496+
(binary > binary_threshold[0])
497+
* (distance > distance_threshold[0])
498+
)
406499
seed = cc3d.connected_components(seed_map)
407500
seed = remove_small_objects(seed, min_seed_size)
408501

409-
foreground = (
410-
(binary > binary_threshold[1])
411-
* (contour < contour_threshold[1])
412-
* (distance > distance_threshold[1])
413-
)
502+
if use_contour:
503+
foreground = (
504+
(binary > binary_threshold[1])
505+
* (contour < contour_threshold[1])
506+
* (distance > distance_threshold[1])
507+
)
508+
else:
509+
# No contour constraint - only binary and distance
510+
foreground = (
511+
(binary > binary_threshold[1])
512+
* (distance > distance_threshold[1])
513+
)
414514

415515
segmentation = mahotas.cwatershed(-distance.astype(np.float64), seed)
416516
segmentation[~foreground] = (

connectomics/lightning/inference.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,11 @@ def write_outputs(
802802
if hasattr(cfg.inference, "postprocessing"):
803803
output_transpose = getattr(cfg.inference.postprocessing, "output_transpose", [])
804804

805+
# Get save_channels from sliding_window config (to reduce memory/storage)
806+
save_channels = None
807+
if hasattr(cfg.inference, "sliding_window"):
808+
save_channels = getattr(cfg.inference.sliding_window, "save_channels", None)
809+
805810
# Determine actual batch size from predictions
806811
# Handle both batched (B, ...) and unbatched (...) predictions
807812
if predictions.ndim >= 4:
@@ -839,6 +844,20 @@ def write_outputs(
839844
filename = filenames[idx]
840845
output_path = output_dir / f"{filename}_{suffix}.h5"
841846

847+
# Select specific channels if save_channels is specified
848+
# save_channels can be a list of indices like [0, 6] to save only channels 0 and 6
849+
# Skip if channels were already filtered during inference (i.e., the sample no longer has more channels than save_channels)
850+
if save_channels is not None and sample.ndim >= 4:
851+
channel_indices = list(save_channels)
852+
num_channels = sample.shape[0]
853+
# Only filter if not already filtered (channels > len(save_channels))
854+
if num_channels > len(channel_indices):
855+
try:
856+
sample = sample[channel_indices]
857+
print(f" Selected channels {channel_indices} from {predictions[idx].shape[0]} channels")
858+
except Exception as e:
859+
print(f" WARNING: write_outputs - channel selection failed: {e}, keeping all channels")
860+
842861
# Transpose if needed (output_transpose: list of axis permutation)
843862
if output_transpose and len(output_transpose) > 0:
844863
try:

connectomics/lightning/lit_model.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,17 @@ def _extract_main_output(self, outputs: Union[torch.Tensor, Dict[str, torch.Tens
286286
def _sliding_window_predict(self, inputs: torch.Tensor) -> torch.Tensor:
287287
"""Wrapper used by MONAI inferer to obtain primary model predictions."""
288288
outputs = self.forward(inputs)
289-
return self._extract_main_output(outputs)
289+
result = self._extract_main_output(outputs)
290+
291+
# Filter channels during inference to reduce memory during accumulation
292+
# This is applied before the sliding window aggregates results
293+
if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'sliding_window'):
294+
save_channels = getattr(self.cfg.inference.sliding_window, 'save_channels', None)
295+
if save_channels is not None:
296+
# Select only specified channels (e.g., [0, 6] for foreground + SDT)
297+
result = result[:, save_channels, ...]
298+
299+
return result
290300

291301
def _apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
292302
"""

0 commit comments

Comments
 (0)