Skip to content

Commit 5d7854a

Browse files
author
Donglai Wei
committed
fix github issues
1 parent c99cd91 commit 5d7854a

File tree

8 files changed

+186
-70
lines changed

8 files changed

+186
-70
lines changed

connectomics/config/hydra_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,9 @@ class InferenceDataConfig:
845845
"predictions.h5" # Output filename (auto-pathed to inference/{checkpoint}/{output_name})
846846
)
847847

848+
# Image transformation (applied to test images during inference)
849+
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
850+
848851
# 2D data support
849852
do_2d: bool = False # Enable 2D data processing for inference
850853

@@ -885,6 +888,9 @@ class TestTimeAugmentationConfig:
885888
save_predictions: bool = (
886889
False # Save intermediate TTA predictions (before decoding) to disk (default: False)
887890
)
891+
save_dtype: Optional[str] = (
892+
None # Data type for saving predictions: "float16", "float32", "uint8", "uint16", or None (keep original)
893+
)
888894

889895

890896
@dataclass

connectomics/data/augment/build.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,19 @@ def _build_eval_transforms_impl(
321321
# else: mode == "test" -> no cropping for sliding window inference
322322

323323
# Normalization - use smart normalization
324-
if cfg.data.image_transform.normalize != "none":
324+
# For test mode, check inference.data.image_transform first, then fall back to data.image_transform
325+
if mode == "test" and hasattr(cfg, "inference") and hasattr(cfg.inference, "data") and hasattr(cfg.inference.data, "image_transform"):
326+
image_transform = cfg.inference.data.image_transform
327+
else:
328+
image_transform = cfg.data.image_transform
329+
330+
if image_transform.normalize != "none":
325331
transforms.append(
326332
SmartNormalizeIntensityd(
327333
keys=["image"],
328-
mode=cfg.data.image_transform.normalize,
329-
clip_percentile_low=cfg.data.image_transform.clip_percentile_low,
330-
clip_percentile_high=cfg.data.image_transform.clip_percentile_high,
334+
mode=image_transform.normalize,
335+
clip_percentile_low=getattr(image_transform, 'clip_percentile_low', 0.0),
336+
clip_percentile_high=getattr(image_transform, 'clip_percentile_high', 1.0),
331337
)
332338
)
333339

connectomics/data/augment/monai_transforms.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,12 +1010,13 @@ class SmartNormalizeIntensityd(MapTransform):
10101010
- "none": No normalization
10111011
- "normal": Z-score normalization (x - mean) / std
10121012
- "0-1": Min-max scaling to [0, 1] (default)
1013+
- "divide-K": Simple divide by K (e.g., "divide-255" for uint8 images)
10131014
10141015
Percentile clipping is applied BEFORE normalization when low > 0.0 or high < 1.0.
10151016
10161017
Args:
10171018
keys: Keys to normalize
1018-
mode: Normalization mode ("none", "normal", "0-1")
1019+
mode: Normalization mode ("none", "normal", "0-1", or "divide-K")
10191020
clip_percentile_low: Lower percentile (0.0 = no clip, 0.05 = 5th percentile)
10201021
clip_percentile_high: Upper percentile (1.0 = no clip, 0.95 = 95th percentile)
10211022
allow_missing_keys: Whether to allow missing keys
@@ -1044,9 +1045,20 @@ def __init__(
10441045
allow_missing_keys: bool = False,
10451046
) -> None:
10461047
super().__init__(keys, allow_missing_keys)
1047-
if mode not in ["none", "normal", "0-1"]:
1048-
raise ValueError(f"Invalid mode '{mode}'. Must be 'none', 'normal', or '0-1'")
1049-
self.mode = mode
1048+
1049+
# Parse mode - support "divide-K" format where K is a number
1050+
self.divide_value = None
1051+
if mode.startswith("divide-"):
1052+
try:
1053+
self.divide_value = float(mode.split("-", 1)[1])
1054+
self.mode = "divide"
1055+
except ValueError:
1056+
raise ValueError(f"Invalid divide mode '{mode}'. Format should be 'divide-K' where K is a number (e.g., 'divide-255')")
1057+
elif mode not in ["none", "normal", "0-1"]:
1058+
raise ValueError(f"Invalid mode '{mode}'. Must be 'none', 'normal', '0-1', or 'divide-K'")
1059+
else:
1060+
self.mode = mode
1061+
10501062
self.clip_percentile_low = clip_percentile_low
10511063
self.clip_percentile_high = clip_percentile_high
10521064

@@ -1088,6 +1100,9 @@ def _normalize(
10881100
max_val = volume.max()
10891101
if max_val > min_val:
10901102
volume = (volume - min_val) / (max_val - min_val)
1103+
elif self.mode == "divide":
1104+
# Simple divide by K (e.g., divide-255 for uint8 images)
1105+
volume = volume / self.divide_value
10911106

10921107
return volume if is_numpy else torch.from_numpy(volume)
10931108

connectomics/data/process/monai_transforms.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,17 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
6767
d = dict(data)
6868
for key in self.key_iterator(d):
6969
if key in d:
70-
d[key] = seg_to_affinity(d[key], self.offsets)
70+
label = d[key]
71+
# Convert tensor to numpy if needed
72+
if isinstance(label, torch.Tensor):
73+
label = label.detach().cpu().numpy()
74+
# Handle channel dimension: squeeze a singleton channel ([1, D, H, W] -> [D, H, W]); any other shape is passed through unchanged
75+
if label.ndim == 4 and label.shape[0] == 1:
76+
label = label[0] # Remove channel dim: [1, D, H, W] -> [D, H, W]
77+
elif label.ndim == 3 and label.shape[0] == 1:
78+
# 2D case: [1, H, W] -> keep as is for 2D affinity
79+
pass
80+
d[key] = seg_to_affinity(label, self.offsets)
7181
return d
7282

7383

@@ -700,6 +710,17 @@ def _prepare_label(self, label: Any) -> Tuple[np.ndarray, bool]:
700710
return np.asarray(label), False
701711

702712
def _to_tensor(self, array: np.ndarray, *, add_batch_dim: bool) -> torch.Tensor:
713+
# Ensure array is a proper numpy array (not a numpy scalar type like numpy.uint8)
714+
# torch.as_tensor cannot infer dtype from numpy scalar types
715+
if not isinstance(array, np.ndarray):
716+
array = np.asarray(array)
717+
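# Widen small integer dtypes before tensor conversion (uint8/uint16 -> float32, int8 -> int32); NOTE(review): torch supports uint8 and int8 natively, so confirm this widening is intended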
718+
if array.dtype == np.uint8:
719+
array = array.astype(np.float32)
720+
elif array.dtype == np.uint16:
721+
array = array.astype(np.float32)
722+
elif array.dtype == np.int8:
723+
array = array.astype(np.int32)
703724
tensor = torch.as_tensor(array)
704725
if self.output_dtype is not None:
705726
tensor = tensor.to(self.output_dtype)

connectomics/lightning/inference.py

Lines changed: 96 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,9 @@ def extract_main_output(
176176

177177
def sliding_window_predict(self, inputs: torch.Tensor) -> torch.Tensor:
178178
"""Wrapper used by MONAI inferer to obtain primary model predictions."""
179-
outputs = self.forward_fn(inputs)
180-
return self.extract_main_output(outputs)
179+
with torch.no_grad():
180+
outputs = self.forward_fn(inputs)
181+
return self.extract_main_output(outputs)
181182

182183
def apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
183184
"""
@@ -220,6 +221,10 @@ def apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
220221

221222
if act == "sigmoid":
222223
channel_tensor = torch.sigmoid(channel_tensor)
224+
elif act == "scale_sigmoid":
225+
# Scaled sigmoid for BANIS: sigmoid(0.2 * x)
226+
# This avoids numerical issues with high-confidence fp16 predictions
227+
channel_tensor = torch.sigmoid(0.2 * channel_tensor)
223228
elif act == "tanh":
224229
channel_tensor = torch.tanh(channel_tensor)
225230
elif act == "softmax":
@@ -237,7 +242,7 @@ def apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
237242
else:
238243
raise ValueError(
239244
f"Unknown activation '{act}' for channels {start_ch}:{end_ch}. "
240-
f"Supported: 'sigmoid', 'softmax', 'tanh', None"
245+
f"Supported: 'sigmoid', 'scale_sigmoid', 'softmax', 'tanh', None"
241246
)
242247

243248
activated_channels.append(channel_tensor)
@@ -334,13 +339,14 @@ def predict_with_tta(
334339
# Handle different tta_flip_axes configurations
335340
if tta_flip_axes_config is None:
336341
# null: No augmentation, but still apply tta_act and tta_channel (no ensemble)
337-
if self.sliding_inferer is not None:
338-
pred = self.sliding_inferer(inputs=images, network=self.sliding_window_predict)
339-
else:
340-
pred = self.sliding_window_predict(images)
342+
with torch.no_grad():
343+
if self.sliding_inferer is not None:
344+
pred = self.sliding_inferer(inputs=images, network=self.sliding_window_predict)
345+
else:
346+
pred = self.sliding_window_predict(images)
341347

342-
# Apply TTA preprocessing (activation + channel selection) even without augmentation
343-
ensemble_result = self.apply_tta_preprocessing(pred)
348+
# Apply TTA preprocessing (activation + channel selection) even without augmentation
349+
ensemble_result = self.apply_tta_preprocessing(pred)
344350
else:
345351
if tta_flip_axes_config == "all" or tta_flip_axes_config == []:
346352
# "all" or []: All flips (all combinations of spatial axes)
@@ -369,7 +375,13 @@ def predict_with_tta(
369375
)
370376

371377
# Apply TTA with flips, preprocessing, and ensembling
372-
predictions = []
378+
# Keep a running ensemble (mean/min/max) instead of accumulating all predictions, to reduce memory usage
379+
ensemble_mode = getattr(
380+
self.cfg.inference.test_time_augmentation, "ensemble_mode", "mean"
381+
)
382+
383+
ensemble_result = None
384+
num_predictions = 0
373385

374386
for flip_axes in tta_flip_axes:
375387
# Apply flip augmentation
@@ -379,40 +391,52 @@ def predict_with_tta(
379391
x_aug = images
380392

381393
# Inference with sliding window
382-
if self.sliding_inferer is not None:
383-
pred = self.sliding_inferer(
384-
inputs=x_aug,
385-
network=self.sliding_window_predict,
386-
)
387-
else:
388-
pred = self.sliding_window_predict(x_aug)
394+
with torch.no_grad():
395+
if self.sliding_inferer is not None:
396+
pred = self.sliding_inferer(
397+
inputs=x_aug,
398+
network=self.sliding_window_predict,
399+
)
400+
else:
401+
pred = self.sliding_window_predict(x_aug)
389402

390-
# Invert flip for prediction
391-
if flip_axes:
392-
pred = Flip(spatial_axis=flip_axes)(pred)
403+
# Invert flip for prediction
404+
if flip_axes:
405+
pred = Flip(spatial_axis=flip_axes)(pred)
393406

394-
# Apply TTA preprocessing (activation + channel selection) if configured
395-
# Note: This is applied BEFORE ensembling for probability-space averaging
396-
pred_processed = self.apply_tta_preprocessing(pred)
407+
# Apply TTA preprocessing (activation + channel selection) if configured
408+
# Note: This is applied BEFORE ensembling for probability-space averaging
409+
pred_processed = self.apply_tta_preprocessing(pred)
397410

398-
predictions.append(pred_processed)
411+
# Free intermediate memory
412+
del pred
413+
if flip_axes:
414+
del x_aug
399415

400-
# Ensemble predictions based on configured mode
401-
ensemble_mode = getattr(
402-
self.cfg.inference.test_time_augmentation, "ensemble_mode", "mean"
403-
)
404-
stacked_preds = torch.stack(predictions, dim=0)
405-
406-
if ensemble_mode == "mean":
407-
ensemble_result = stacked_preds.mean(dim=0)
408-
elif ensemble_mode == "min":
409-
ensemble_result = stacked_preds.min(dim=0)[0] # min returns (values, indices)
410-
elif ensemble_mode == "max":
411-
ensemble_result = stacked_preds.max(dim=0)[0] # max returns (values, indices)
412-
else:
413-
raise ValueError(
414-
f"Unknown TTA ensemble mode: {ensemble_mode}. Use 'mean', 'min', or 'max'."
415-
)
416+
# Update running ensemble to reduce memory usage
417+
if ensemble_result is None:
418+
ensemble_result = pred_processed.clone()
419+
else:
420+
if ensemble_mode == "mean":
421+
# Running average: new_avg = old_avg + (new_val - old_avg) / n, where n = num_predictions + 1 (the count including the current prediction)
422+
ensemble_result = ensemble_result + (pred_processed - ensemble_result) / (num_predictions + 1)
423+
elif ensemble_mode == "min":
424+
ensemble_result = torch.minimum(ensemble_result, pred_processed)
425+
elif ensemble_mode == "max":
426+
ensemble_result = torch.maximum(ensemble_result, pred_processed)
427+
else:
428+
raise ValueError(
429+
f"Unknown TTA ensemble mode: {ensemble_mode}. Use 'mean', 'min', or 'max'."
430+
)
431+
432+
num_predictions += 1
433+
434+
# Free processed prediction memory
435+
del pred_processed
436+
437+
# Release cached CUDA memory after every 4th prediction to reduce the risk of OOM
438+
if torch.cuda.is_available() and num_predictions % 4 == 0:
439+
torch.cuda.empty_cache()
416440

417441
# Apply mask after ensemble if requested
418442
apply_mask = getattr(self.cfg.inference.test_time_augmentation, "apply_mask", False)
@@ -868,9 +892,40 @@ def write_outputs(
868892
# Squeeze singleton dimensions (e.g., (1, 1, D, H, W) -> (D, H, W))
869893
sample = np.squeeze(sample)
870894

895+
# Convert to specified dtype if save_dtype is set
896+
save_dtype = None
897+
if hasattr(cfg.inference, "test_time_augmentation"):
898+
save_dtype = getattr(cfg.inference.test_time_augmentation, "save_dtype", None)
899+
900+
if save_dtype is not None:
901+
original_dtype = sample.dtype
902+
if save_dtype == "float16":
903+
sample = sample.astype(np.float16)
904+
elif save_dtype == "float32":
905+
sample = sample.astype(np.float32)
906+
elif save_dtype == "uint8":
907+
# For uint8, detect value range and scale appropriately
908+
if sample.min() < 0:
909+
# [-1, 1] to [0, 255]
910+
sample = ((sample + 1) * 127.5).clip(0, 255).astype(np.uint8)
911+
elif sample.max() <= 1.0:
912+
# [0, 1] to [0, 255]
913+
sample = (sample * 255).clip(0, 255).astype(np.uint8)
914+
else:
915+
sample = sample.clip(0, 255).astype(np.uint8)
916+
elif save_dtype == "uint16":
917+
# For uint16, scale from [0, 1] to [0, 65535]
918+
if sample.max() <= 1.0:
919+
sample = (sample * 65535).clip(0, 65535).astype(np.uint16)
920+
else:
921+
sample = sample.clip(0, 65535).astype(np.uint16)
922+
else:
923+
print(f" WARNING: Unknown save_dtype '{save_dtype}', keeping original dtype")
924+
871925
# Write HDF5 file
872926
try:
873927
write_hdf5(str(output_path), sample, dataset="main")
874-
print(f" Saved prediction: {output_path} (shape: {sample.shape})")
928+
dtype_info = f", dtype: {sample.dtype}" if save_dtype else ""
929+
print(f" Saved prediction: {output_path} (shape: {sample.shape}{dtype_info})")
875930
except Exception as e:
876931
print(f" ERROR: write_outputs - failed to write {output_path}: {e}")

connectomics/lightning/lit_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ def _apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
337337

338338
if act == 'sigmoid':
339339
channel_tensor = torch.sigmoid(channel_tensor)
340+
elif act == 'scale_sigmoid':
341+
# Scaled sigmoid for BANIS: sigmoid(0.2 * x)
342+
# This avoids numerical issues with high-confidence fp16 predictions
343+
channel_tensor = torch.sigmoid(0.2 * channel_tensor)
340344
elif act == 'tanh':
341345
channel_tensor = torch.tanh(channel_tensor)
342346
elif act == 'softmax':
@@ -354,7 +358,7 @@ def _apply_tta_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
354358
else:
355359
raise ValueError(
356360
f"Unknown activation '{act}' for channels {start_ch}:{end_ch}. "
357-
f"Supported: 'sigmoid', 'softmax', 'tanh', None"
361+
f"Supported: 'sigmoid', 'scale_sigmoid', 'softmax', 'tanh', None"
358362
)
359363

360364
activated_channels.append(channel_tensor)

connectomics/utils/demo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def create_demo_config():
8080
from connectomics.config import Config
8181
from connectomics.config.hydra_config import (
8282
SystemConfig,
83-
TrainingSystemConfig,
84-
InferenceSystemConfig,
83+
SystemTrainingConfig,
84+
SystemInferenceConfig,
8585
ModelConfig,
8686
DataConfig,
8787
OptimizationConfig,
@@ -97,13 +97,13 @@ def create_demo_config():
9797
cfg = Config(
9898
system=SystemConfig(
9999
seed=42,
100-
training=TrainingSystemConfig(
100+
training=SystemTrainingConfig(
101101
num_gpus=1 if torch.cuda.is_available() else 0,
102102
num_cpus=2,
103103
batch_size=2,
104104
num_workers=0, # 0 for demo to avoid multiprocessing issues
105105
),
106-
inference=InferenceSystemConfig(
106+
inference=SystemInferenceConfig(
107107
num_gpus=1 if torch.cuda.is_available() else 0,
108108
num_cpus=2,
109109
batch_size=2,

0 commit comments

Comments
 (0)