@@ -176,8 +176,9 @@ def extract_main_output(
176176
177177 def sliding_window_predict (self , inputs : torch .Tensor ) -> torch .Tensor :
178178 """Wrapper used by MONAI inferer to obtain primary model predictions."""
179- outputs = self .forward_fn (inputs )
180- return self .extract_main_output (outputs )
179+ with torch .no_grad ():
180+ outputs = self .forward_fn (inputs )
181+ return self .extract_main_output (outputs )
181182
182183 def apply_tta_preprocessing (self , tensor : torch .Tensor ) -> torch .Tensor :
183184 """
@@ -220,6 +221,10 @@ def apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
220221
221222 if act == "sigmoid" :
222223 channel_tensor = torch .sigmoid (channel_tensor )
224+ elif act == "scale_sigmoid" :
225+ # Scaled sigmoid for BANIS: sigmoid(0.2 * x)
226+ # This avoids numerical issues with high-confidence fp16 predictions
227+ channel_tensor = torch .sigmoid (0.2 * channel_tensor )
223228 elif act == "tanh" :
224229 channel_tensor = torch .tanh (channel_tensor )
225230 elif act == "softmax" :
@@ -237,7 +242,7 @@ def apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
237242 else :
238243 raise ValueError (
239244 f"Unknown activation '{ act } ' for channels { start_ch } :{ end_ch } . "
240- f"Supported: 'sigmoid', 'softmax', 'tanh', None"
245+ f"Supported: 'sigmoid', 'scale_sigmoid', ' softmax', 'tanh', None"
241246 )
242247
243248 activated_channels .append (channel_tensor )
@@ -334,13 +339,14 @@ def predict_with_tta(
334339 # Handle different tta_flip_axes configurations
335340 if tta_flip_axes_config is None :
336341 # null: No augmentation, but still apply tta_act and tta_channel (no ensemble)
337- if self .sliding_inferer is not None :
338- pred = self .sliding_inferer (inputs = images , network = self .sliding_window_predict )
339- else :
340- pred = self .sliding_window_predict (images )
342+ with torch .no_grad ():
343+ if self .sliding_inferer is not None :
344+ pred = self .sliding_inferer (inputs = images , network = self .sliding_window_predict )
345+ else :
346+ pred = self .sliding_window_predict (images )
341347
342- # Apply TTA preprocessing (activation + channel selection) even without augmentation
343- ensemble_result = self .apply_tta_preprocessing (pred )
348+ # Apply TTA preprocessing (activation + channel selection) even without augmentation
349+ ensemble_result = self .apply_tta_preprocessing (pred )
344350 else :
345351 if tta_flip_axes_config == "all" or tta_flip_axes_config == []:
346352 # "all" or []: All flips (all combinations of spatial axes)
@@ -369,7 +375,13 @@ def predict_with_tta(
369375 )
370376
371377 # Apply TTA with flips, preprocessing, and ensembling
372- predictions = []
378+ # Use running average to reduce memory usage instead of accumulating all predictions
379+ ensemble_mode = getattr (
380+ self .cfg .inference .test_time_augmentation , "ensemble_mode" , "mean"
381+ )
382+
383+ ensemble_result = None
384+ num_predictions = 0
373385
374386 for flip_axes in tta_flip_axes :
375387 # Apply flip augmentation
@@ -379,40 +391,52 @@ def predict_with_tta(
379391 x_aug = images
380392
381393 # Inference with sliding window
382- if self .sliding_inferer is not None :
383- pred = self .sliding_inferer (
384- inputs = x_aug ,
385- network = self .sliding_window_predict ,
386- )
387- else :
388- pred = self .sliding_window_predict (x_aug )
394+ with torch .no_grad ():
395+ if self .sliding_inferer is not None :
396+ pred = self .sliding_inferer (
397+ inputs = x_aug ,
398+ network = self .sliding_window_predict ,
399+ )
400+ else :
401+ pred = self .sliding_window_predict (x_aug )
389402
390- # Invert flip for prediction
391- if flip_axes :
392- pred = Flip (spatial_axis = flip_axes )(pred )
403+ # Invert flip for prediction
404+ if flip_axes :
405+ pred = Flip (spatial_axis = flip_axes )(pred )
393406
394- # Apply TTA preprocessing (activation + channel selection) if configured
395- # Note: This is applied BEFORE ensembling for probability-space averaging
396- pred_processed = self .apply_tta_preprocessing (pred )
407+ # Apply TTA preprocessing (activation + channel selection) if configured
408+ # Note: This is applied BEFORE ensembling for probability-space averaging
409+ pred_processed = self .apply_tta_preprocessing (pred )
397410
398- predictions .append (pred_processed )
411+ # Free intermediate memory
412+ del pred
413+ if flip_axes :
414+ del x_aug
399415
400- # Ensemble predictions based on configured mode
401- ensemble_mode = getattr (
402- self .cfg .inference .test_time_augmentation , "ensemble_mode" , "mean"
403- )
404- stacked_preds = torch .stack (predictions , dim = 0 )
405-
406- if ensemble_mode == "mean" :
407- ensemble_result = stacked_preds .mean (dim = 0 )
408- elif ensemble_mode == "min" :
409- ensemble_result = stacked_preds .min (dim = 0 )[0 ] # min returns (values, indices)
410- elif ensemble_mode == "max" :
411- ensemble_result = stacked_preds .max (dim = 0 )[0 ] # max returns (values, indices)
412- else :
413- raise ValueError (
414- f"Unknown TTA ensemble mode: { ensemble_mode } . Use 'mean', 'min', or 'max'."
415- )
416+ # Update running ensemble to reduce memory usage
417+ if ensemble_result is None :
418+ ensemble_result = pred_processed .clone ()
419+ else :
420+ if ensemble_mode == "mean" :
421+ # Running average: new_avg = old_avg + (new_val - old_avg) / n
422+ ensemble_result = ensemble_result + (pred_processed - ensemble_result ) / (num_predictions + 1 )
423+ elif ensemble_mode == "min" :
424+ ensemble_result = torch .minimum (ensemble_result , pred_processed )
425+ elif ensemble_mode == "max" :
426+ ensemble_result = torch .maximum (ensemble_result , pred_processed )
427+ else :
428+ raise ValueError (
429+ f"Unknown TTA ensemble mode: { ensemble_mode } . Use 'mean', 'min', or 'max'."
430+ )
431+
432+ num_predictions += 1
433+
434+ # Free processed prediction memory
435+ del pred_processed
436+
437+ # Force CUDA cache clear periodically to prevent OOM
438+ if torch .cuda .is_available () and num_predictions % 4 == 0 :
439+ torch .cuda .empty_cache ()
416440
417441 # Apply mask after ensemble if requested
418442 apply_mask = getattr (self .cfg .inference .test_time_augmentation , "apply_mask" , False )
@@ -868,9 +892,40 @@ def write_outputs(
868892 # Squeeze singleton dimensions (e.g., (1, 1, D, H, W) -> (D, H, W))
869893 sample = np .squeeze (sample )
870894
895+ # Convert to specified dtype if save_dtype is set
896+ save_dtype = None
897+ if hasattr (cfg .inference , "test_time_augmentation" ):
898+ save_dtype = getattr (cfg .inference .test_time_augmentation , "save_dtype" , None )
899+
900+ if save_dtype is not None :
901+ original_dtype = sample .dtype
902+ if save_dtype == "float16" :
903+ sample = sample .astype (np .float16 )
904+ elif save_dtype == "float32" :
905+ sample = sample .astype (np .float32 )
906+ elif save_dtype == "uint8" :
907+ # For uint8, detect value range and scale appropriately
908+ if sample .min () < 0 :
909+ # [-1, 1] to [0, 255]
910+ sample = ((sample + 1 ) * 127.5 ).clip (0 , 255 ).astype (np .uint8 )
911+ elif sample .max () <= 1.0 :
912+ # [0, 1] to [0, 255]
913+ sample = (sample * 255 ).clip (0 , 255 ).astype (np .uint8 )
914+ else :
915+ sample = sample .clip (0 , 255 ).astype (np .uint8 )
916+ elif save_dtype == "uint16" :
917+ # For uint16, scale from [0, 1] to [0, 65535]
918+ if sample .max () <= 1.0 :
919+ sample = (sample * 65535 ).clip (0 , 65535 ).astype (np .uint16 )
920+ else :
921+ sample = sample .clip (0 , 65535 ).astype (np .uint16 )
922+ else :
923+ print (f" WARNING: Unknown save_dtype '{ save_dtype } ', keeping original dtype" )
924+
871925 # Write HDF5 file
872926 try :
873927 write_hdf5 (str (output_path ), sample , dataset = "main" )
874- print (f" Saved prediction: { output_path } (shape: { sample .shape } )" )
928+ dtype_info = f", dtype: { sample .dtype } " if save_dtype else ""
929+ print (f" Saved prediction: { output_path } (shape: { sample .shape } { dtype_info } )" )
875930 except Exception as e :
876931 print (f" ERROR: write_outputs - failed to write { output_path } : { e } " )
0 commit comments