@@ -244,74 +244,111 @@ def seg_to_binary(label, segment_id=[]):
244244 return fg_mask
245245
246246
247- def seg_to_affinity (seg : np .ndarray , target_opt : List [str ]) -> np .ndarray :
247+ def seg_to_affinity (
248+ seg : np .ndarray ,
249+ offsets : List [str ] = None ,
250+ long_range : int = None ,
251+ ) -> np .ndarray :
248252 """
249- Compute affinities from a segmentation based on target options.
253+ Compute affinity maps from segmentation.
254+
255+ Supports two modes:
256+ 1. DeepEM/SNEMI style: Provide `offsets` as list of strings (e.g., ["0-0-1", "0-1-0", "1-0-0"])
257+ 2. BANIS style: Provide `long_range` as int for 6-channel output (3 short + 3 long range)
250258
251259 Args:
252260 seg: The segmentation to compute affinities from. Shape: (z, y, x).
253- target_opt: List of strings defining affinity offsets.
254- Can be either:
255- - Legacy format: ['1', '0-0-1', '0-1-0', ...] (first element is type indicator)
256- - Modern format: ['0-0-1', '0-1-0', ...] (direct offset list)
261+ 0 indicates background.
262+ offsets: List of offset strings in "z-y-x" format (e.g., ["0-0-1", "0-1-0", "1-0-0"]).
263+ Each string defines one affinity channel.
264+ long_range: BANIS-style: offset for long-range affinities. Produces 6 channels:
265+ - Channel 0-2: Short-range (offset 1) for z, y, x
266+ - Channel 3-5: Long-range (offset long_range) for z, y, x
257267
258268 Returns:
259- The affinities. Shape: (num_offsets , z, y, x).
269+ The affinities. Shape: (num_channels , z, y, x).
260270 """
261- if len (target_opt ) == 0 :
262- # Default short-range affinities
263- offsets = [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]
264- else :
265- # Detect format: check if first element is a type indicator or an offset
266- start_idx = 0
267- if len (target_opt ) > 0 and "-" not in target_opt [0 ]:
268- # Legacy format: first element is type indicator (e.g., '1')
269- start_idx = 1
270-
271- # Parse offsets from target_opt
272- offsets = []
273- for opt_str in target_opt [start_idx :]:
274- if "-" in opt_str :
275- offset = [int (x ) for x in opt_str .split ("-" )]
276- offsets .append (offset )
277-
278- # Fallback to default if no valid offsets found
279- if len (offsets ) == 0 :
280- offsets = [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]
281-
282- num_offsets = len (offsets )
283- affinities = np .zeros ((num_offsets , * seg .shape ), dtype = np .float32 )
284-
285- for i , offset in enumerate (offsets ):
286- dz , dy , dx = offset
287-
288- # Create slices for the offset
271+ # BANIS mode: use long_range parameter (takes precedence if specified)
272+ if long_range is not None :
273+ affinities = np .zeros ((6 , * seg .shape ), dtype = np .float32 )
274+
275+ # Short range affinities (offset 1)
276+ affinities [0 , :- 1 ] = (seg [:- 1 ] == seg [1 :]) & (seg [1 :] > 0 )
277+ affinities [1 , :, :- 1 ] = (seg [:, :- 1 ] == seg [:, 1 :]) & (seg [:, 1 :] > 0 )
278+ affinities [2 , :, :, :- 1 ] = (seg [:, :, :- 1 ] == seg [:, :, 1 :]) & (seg [:, :, 1 :] > 0 )
279+
280+ # Long range affinities
281+ affinities [3 , :- long_range ] = (seg [:- long_range ] == seg [long_range :]) & (seg [long_range :] > 0 )
282+ affinities [4 , :, :- long_range ] = (seg [:, :- long_range ] == seg [:, long_range :]) & (seg [:, long_range :] > 0 )
283+ affinities [5 , :, :, :- long_range ] = (seg [:, :, :- long_range ] == seg [:, :, long_range :]) & (seg [:, :, long_range :] > 0 )
284+
285+ return affinities
286+
287+ # DeepEM/SNEMI mode: use offsets parameter
288+ if offsets is None :
289+ # Default: short-range affinities for z, y, x
290+ offsets = ["1-0-0" , "0-1-0" , "0-0-1" ]
291+
292+ # Parse offsets from strings
293+ parsed_offsets = []
294+ for offset_str in offsets :
295+ parts = offset_str .split ("-" )
296+ if len (parts ) == 3 :
297+ parsed_offsets .append ([int (parts [0 ]), int (parts [1 ]), int (parts [2 ])])
298+ else :
299+ raise ValueError (f"Invalid offset format: { offset_str } . Expected 'z-y-x' format." )
300+
301+ num_channels = len (parsed_offsets )
302+ affinities = np .zeros ((num_channels , * seg .shape ), dtype = np .float32 )
303+
304+ for i , (dz , dy , dx ) in enumerate (parsed_offsets ):
305+ # Handle each axis independently
306+ # For positive offset: compare seg[:-offset] with seg[offset:]
307+ # For negative offset: compare seg[-offset:] with seg[:offset]
308+
309+ if dz == 0 and dy == 0 and dx == 0 :
310+ # Zero offset: all foreground pixels are 1
311+ affinities [i ] = (seg > 0 ).astype (np .float32 )
312+ continue
313+
314+ # Build source and destination slices for each axis
289315 if dz > 0 :
290- src_slice = ( slice (None , - dz ), slice ( None ), slice ( None ) )
291- dst_slice = ( slice (dz , None ), slice ( None ), slice ( None ) )
316+ z_src = slice (None , - dz )
317+ z_dst = slice (dz , None )
292318 elif dz < 0 :
293- src_slice = ( slice (- dz , None ), slice ( None ), slice ( None ) )
294- dst_slice = ( slice (None , dz ), slice ( None ), slice ( None ) )
319+ z_src = slice (- dz , None )
320+ z_dst = slice (None , dz )
295321 else :
296- src_slice = ( slice (None ), slice ( None ), slice ( None ) )
297- dst_slice = ( slice (None ), slice ( None ), slice ( None ) )
322+ z_src = slice (None )
323+ z_dst = slice (None )
298324
299325 if dy > 0 :
300- src_slice = ( src_slice [ 0 ], slice (None , - dy ), src_slice [ 2 ] )
301- dst_slice = ( dst_slice [ 0 ], slice (dy , None ), dst_slice [ 2 ] )
326+ y_src = slice (None , - dy )
327+ y_dst = slice (dy , None )
302328 elif dy < 0 :
303- src_slice = (src_slice [0 ], slice (- dy , None ), src_slice [2 ])
304- dst_slice = (dst_slice [0 ], slice (None , dy ), dst_slice [2 ])
329+ y_src = slice (- dy , None )
330+ y_dst = slice (None , dy )
331+ else :
332+ y_src = slice (None )
333+ y_dst = slice (None )
305334
306335 if dx > 0 :
307- src_slice = ( src_slice [ 0 ], src_slice [ 1 ], slice (None , - dx ) )
308- dst_slice = ( dst_slice [ 0 ], dst_slice [ 1 ], slice (dx , None ) )
336+ x_src = slice (None , - dx )
337+ x_dst = slice (dx , None )
309338 elif dx < 0 :
310- src_slice = (src_slice [0 ], src_slice [1 ], slice (- dx , None ))
311- dst_slice = (dst_slice [0 ], dst_slice [1 ], slice (None , dx ))
339+ x_src = slice (- dx , None )
340+ x_dst = slice (None , dx )
341+ else :
342+ x_src = slice (None )
343+ x_dst = slice (None )
344+
345+ src_slice = (z_src , y_src , x_src )
346+ dst_slice = (z_dst , y_dst , x_dst )
312347
313- # Compute affinity
314- affinities [i ][dst_slice ] = (seg [src_slice ] == seg [dst_slice ]) & (seg [dst_slice ] > 0 )
348+ # Compute affinity: same segment ID and not background
349+ affinities [i ][dst_slice ] = (
350+ (seg [src_slice ] == seg [dst_slice ]) & (seg [dst_slice ] > 0 )
351+ ).astype (np .float32 )
315352
316353 return affinities
317354
0 commit comments