1515"""
1616
1717from __future__ import print_function , division
18- from typing import Optional , Tuple
18+ from typing import Optional , Tuple , List
1919import numpy as np
2020import cc3d
2121import fastremap
@@ -352,14 +352,17 @@ def decode_binary_contour_watershed(
352352def decode_binary_contour_distance_watershed (
353353 predictions : np .ndarray ,
354354 binary_threshold : Tuple [float , float ] = (0.9 , 0.85 ),
355- contour_threshold : Tuple [float , float ] = (0.8 , 1.1 ),
355+ contour_threshold : Optional [ Tuple [float , float ] ] = (0.8 , 1.1 ),
356356 distance_threshold : Tuple [float , float ] = (0.5 , 0 ),
357357 min_instance_size : int = 128 ,
358358 remove_small_mode : str = "background" ,
359359 min_seed_size : int = 32 ,
360360 return_seed : bool = False ,
361361 precomputed_seed : Optional [np .ndarray ] = None ,
362362 prediction_scale : int = 255 ,
363+ binary_channels : Optional [List [int ]] = None ,
364+ contour_channels : Optional [List [int ]] = None ,
365+ distance_channels : Optional [List [int ]] = None ,
363366):
364367 r"""Convert binary foreground probability maps, instance contours and signed distance
365368 transform to instance masks via watershed segmentation algorithm.
@@ -369,48 +372,145 @@ def decode_binary_contour_distance_watershed(
369372 function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays.
370373
371374 Args:
372- predictions (numpy.ndarray): foreground, contour, and distance probability of shape :math:`(3, Z, Y, X)`.
375+ predictions (numpy.ndarray): foreground, contour, and distance probability of shape :math:`(3, Z, Y, X)`
376+ or :math:`(2, Z, Y, X)` if contour is disabled.
373377 binary_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for binary mask.
374378 The first value is used for seed generation, the second for foreground mask. Default: (0.9, 0.85)
375- contour_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for instance contours.
376- The first value is used for seed generation, the second for foreground mask. Default: (0.8, 1.1)
379+ contour_threshold (tuple or None): tuple of two floats (seed_threshold, foreground_threshold) for instance contours.
380+ The first value is used for seed generation, the second for foreground mask.
381+ Set to None to disable contour constraints (for BANIS-style binary+distance only). Default: (0.8, 1.1)
377382 distance_threshold (tuple): tuple of two floats (seed_threshold, foreground_threshold) for signed distance.
378383 The first value is used for seed generation, the second for foreground mask. Default: (0.5, -0.5)
379384 min_instance_size (int): minimum size threshold for instances to keep. Default: 128
380385 remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``. Default: ``'background'``
381386 min_seed_size (int): minimum size of seed objects. Default: 32
382387 return_seed (bool): whether to return the seed map. Default: False
383388 precomputed_seed (numpy.ndarray, optional): precomputed seed map. Default: None
384- prediction_scale (int): scale of input predictions (255 for uint8 range). Default: 255
389+ prediction_scale (int): scale of input predictions (255 for uint8 range, 1 for 0-1 range). Default: 255
390+ binary_channels (list of int, optional): channel indices for binary mask. If multiple, they are averaged.
391+ Default: None (uses position-based assignment)
392+ contour_channels (list of int, optional): channel indices for contour. If multiple, they are averaged.
393+ Default: None (uses position-based assignment)
394+ distance_channels (list of int, optional): channel indices for distance. If multiple, they are averaged.
395+ Default: None (uses position-based assignment)
385396
386397 Returns:
387398 numpy.ndarray or tuple: Instance segmentation mask, or (mask, seed) if return_seed=True.
388- """
389- assert predictions .shape [0 ] == 3
390- binary , contour , distance = predictions [0 ], predictions [1 ], predictions [2 ]
391399
400+ Example:
401+ >>> # Standard 3-channel (binary, contour, distance)
402+ >>> seg = decode_binary_contour_distance_watershed(predictions)
403+
404+ >>> # BANIS-style 2-channel (binary, distance) - no contour
405+ >>> seg = decode_binary_contour_distance_watershed(
406+ ... predictions, # shape (2, Z, Y, X)
407+ ... binary_threshold=(0.5, 0.5),
408+ ... contour_threshold=None, # Disable contour
409+ ... distance_threshold=(0.0, -1.0),
410+ ... prediction_scale=1,
411+ ... )
412+
413+ >>> # Explicit channel selection with averaging
414+ >>> seg = decode_binary_contour_distance_watershed(
415+ ... predictions, # shape (3, Z, Y, X) with channels [aff_x, aff_y, SDT]
416+ ... binary_channels=[0, 1], # Average channels 0 and 1 for binary
417+ ... contour_channels=None, # No contour
418+ ... distance_channels=[2], # Channel 2 for distance
419+ ... contour_threshold=None,
420+ ... prediction_scale=1,
421+ ... )
422+ """
423+ # Check if contour is disabled
424+ use_contour = contour_threshold is not None
425+
426+ # Extract channels using explicit selection or position-based fallback
427+ if binary_channels is not None or distance_channels is not None :
428+ # Explicit channel selection mode
429+ if binary_channels is not None :
430+ if len (binary_channels ) > 1 :
431+ binary = predictions [binary_channels ].mean (axis = 0 )
432+ else :
433+ binary = predictions [binary_channels [0 ]]
434+ else :
435+ binary = predictions [0 ]
436+
437+ if distance_channels is not None :
438+ if len (distance_channels ) > 1 :
439+ distance = predictions [distance_channels ].mean (axis = 0 )
440+ else :
441+ distance = predictions [distance_channels [0 ]]
442+ else :
443+ distance = predictions [- 1 ]
444+
445+ if use_contour :
446+ if contour_channels is not None :
447+ if len (contour_channels ) > 1 :
448+ contour = predictions [contour_channels ].mean (axis = 0 )
449+ else :
450+ contour = predictions [contour_channels [0 ]]
451+ else :
452+ # Default: assume contour is second-to-last if using contour
453+ contour = predictions [- 2 ]
454+ else :
455+ contour = None
456+ else :
457+ # Position-based fallback (legacy behavior)
458+ if use_contour :
459+ assert predictions .shape [0 ] >= 3 , f"Expected at least 3 channels (binary, contour, distance), got { predictions .shape [0 ]} "
460+ # If more than 3 channels, first N-2 channels are binary (average them)
461+ if predictions .shape [0 ] > 3 :
462+ binary = predictions [:- 2 ].mean (axis = 0 )
463+ contour , distance = predictions [- 2 ], predictions [- 1 ]
464+ else :
465+ binary , contour , distance = predictions [0 ], predictions [1 ], predictions [2 ]
466+ else :
467+ assert predictions .shape [0 ] >= 2 , f"Expected at least 2 channels (binary, distance) when contour disabled, got { predictions .shape [0 ]} "
468+ # If more than 2 channels, first N-1 channels are binary (average them)
469+ if predictions .shape [0 ] > 2 :
470+ binary = predictions [:- 1 ].mean (axis = 0 )
471+ distance = predictions [- 1 ]
472+ else :
473+ binary , distance = predictions [0 ], predictions [1 ]
474+ contour = None
475+
476+ # Convert thresholds based on prediction scale
392477 if prediction_scale == 255 :
393478 distance = (distance / prediction_scale ) * 2.0 - 1.0
394- binary_threshold = binary_threshold * prediction_scale
395- contour_threshold = contour_threshold * prediction_scale
396- distance_threshold = distance_threshold * prediction_scale
479+ binary_threshold = (binary_threshold [0 ] * prediction_scale , binary_threshold [1 ] * prediction_scale )
480+ if use_contour :
481+ contour_threshold = (contour_threshold [0 ] * prediction_scale , contour_threshold [1 ] * prediction_scale )
482+ distance_threshold = (distance_threshold [0 ] * prediction_scale , distance_threshold [1 ] * prediction_scale )
397483
398484 if precomputed_seed is not None :
399485 seed = precomputed_seed
400486 else : # compute the instance seeds
401- seed_map = (
402- (binary > binary_threshold [0 ])
403- * (contour < contour_threshold [0 ])
404- * (distance > distance_threshold [0 ])
405- )
487+ if use_contour :
488+ seed_map = (
489+ (binary > binary_threshold [0 ])
490+ * (contour < contour_threshold [0 ])
491+ * (distance > distance_threshold [0 ])
492+ )
493+ else :
494+ # No contour constraint - only binary and distance
495+ seed_map = (
496+ (binary > binary_threshold [0 ])
497+ * (distance > distance_threshold [0 ])
498+ )
406499 seed = cc3d .connected_components (seed_map )
407500 seed = remove_small_objects (seed , min_seed_size )
408501
409- foreground = (
410- (binary > binary_threshold [1 ])
411- * (contour < contour_threshold [1 ])
412- * (distance > distance_threshold [1 ])
413- )
502+ if use_contour :
503+ foreground = (
504+ (binary > binary_threshold [1 ])
505+ * (contour < contour_threshold [1 ])
506+ * (distance > distance_threshold [1 ])
507+ )
508+ else :
509+ # No contour constraint - only binary and distance
510+ foreground = (
511+ (binary > binary_threshold [1 ])
512+ * (distance > distance_threshold [1 ])
513+ )
414514
415515 segmentation = mahotas .cwatershed (- distance .astype (np .float64 ), seed )
416516 segmentation [~ foreground ] = (
0 commit comments