Skip to content

Commit 73f822c

Browse files
committed
--unsafe-fixes
1 parent 4d729d6 commit 73f822c

File tree

11 files changed

+37
-36
lines changed

11 files changed

+37
-36
lines changed

monai/apps/deepgrow/dataset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
201201
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")
202202

203203
logging.info(
204-
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count}; Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count}; Unique Labels: {unique_labels_count}"
204+
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};"
205+
f" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};"
206+
f" Unique Labels: {unique_labels_count}"
205207
)
206208
return data_list
207209

@@ -252,6 +254,8 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
252254
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")
253255

254256
logging.info(
255-
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count}; Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count}; Unique Labels: {unique_labels_count}"
257+
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};"
258+
f" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};"
259+
f" Unique Labels: {unique_labels_count}"
256260
)
257261
return data_list

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import os
1414
import shutil
1515
from pathlib import Path
16-
from typing import Any, Optional, Union
16+
from typing import Any
1717

1818
import numpy as np
1919
import torch
@@ -36,17 +36,17 @@
3636

3737

3838
def get_nnunet_trainer(
39-
dataset_name_or_id: Union[str, int],
39+
dataset_name_or_id: str | int,
4040
configuration: str,
41-
fold: Union[int, str],
41+
fold: int | str,
4242
trainer_class_name: str = "nnUNetTrainer",
4343
plans_identifier: str = "nnUNetPlans",
4444
use_compressed_data: bool = False,
4545
continue_training: bool = False,
4646
only_run_validation: bool = False,
4747
disable_checkpointing: bool = False,
4848
device: str = "cuda",
49-
pretrained_model: Optional[str] = None,
49+
pretrained_model: str | None = None,
5050
) -> Any: # type: ignore
5151
"""
5252
Get the nnUNet trainer instance based on the provided configuration.
@@ -166,7 +166,7 @@ class ModelnnUNetWrapper(torch.nn.Module):
166166
restoring network architecture, and setting up the predictor for inference.
167167
"""
168168

169-
def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore
169+
def __init__(self, predictor: object, model_folder: str | Path, model_name: str = "model.pt"): # type: ignore
170170
super().__init__()
171171
self.predictor = predictor
172172

@@ -294,7 +294,7 @@ def forward(self, x: MetaTensor) -> MetaTensor:
294294
return MetaTensor(out_tensor, meta=x.meta)
295295

296296

297-
def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
297+
def get_nnunet_monai_predictor(model_folder: str | Path, model_name: str = "model.pt") -> ModelnnUNetWrapper:
298298
"""
299299
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
300300
The model folder should contain the following files, created during training:
@@ -426,9 +426,9 @@ def get_network_from_nnunet_plans(
426426
plans_file: str,
427427
dataset_file: str,
428428
configuration: str,
429-
model_ckpt: Optional[str] = None,
429+
model_ckpt: str | None = None,
430430
model_key_in_ckpt: str = "model",
431-
) -> Union[torch.nn.Module, Any]:
431+
) -> torch.nn.Module | Any:
432432
"""
433433
Load and initialize a nnUNet network based on nnUNet plans and configuration.
434434
@@ -518,7 +518,7 @@ def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str,
518518
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
519519

520520
def subfiles(
521-
folder: Union[str, Path], prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
521+
folder: str | Path, prefix: str | None = None, suffix: str | None = None, sort: bool = True
522522
) -> list[str]:
523523
res = [
524524
i.name

monai/losses/adversarial_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def __init__(
5757

5858
if criterion.lower() not in list(AdversarialCriterions):
5959
raise ValueError(
60-
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
61-
% ", ".join(AdversarialCriterions)
60+
"Unrecognised criterion entered for Adversarial Loss. Must be one in: {}".format(", ".join(AdversarialCriterions))
6261
)
6362

6463
# Depending on the criterion, a different activation layer is used.

monai/losses/dice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def __init__(
494494
raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.")
495495

496496
if weighting_mode not in ["default", "GDL"]:
497-
raise ValueError("weighting_mode must be either 'default' or 'GDL, got %s." % weighting_mode)
497+
raise ValueError(f"weighting_mode must be either 'default' or 'GDL, got {weighting_mode}.")
498498

499499
self.m = dist_matrix
500500
if isinstance(self.m, np.ndarray):

monai/losses/ds_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import Union
1514

1615
import torch
1716
import torch.nn.functional as F
@@ -70,7 +69,7 @@ def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
7069
target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
7170
return self.loss(input, target) # type: ignore[no-any-return]
7271

73-
def forward(self, input: Union[None, torch.Tensor, list[torch.Tensor]], target: torch.Tensor) -> torch.Tensor:
72+
def forward(self, input: None | torch.Tensor | list[torch.Tensor], target: torch.Tensor) -> torch.Tensor:
7473
if isinstance(input, (list, tuple)):
7574
weights = self.get_weights(levels=len(input))
7675
loss = torch.tensor(0, dtype=torch.float, device=target.device)

monai/losses/focal_loss.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import warnings
1515
from collections.abc import Sequence
16-
from typing import Optional
1716

1817
import torch
1918
import torch.nn.functional as F
@@ -153,7 +152,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
153152
if target.shape != input.shape:
154153
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
155154

156-
loss: Optional[torch.Tensor] = None
155+
loss: torch.Tensor | None = None
157156
input = input.float()
158157
target = target.float()
159158
if self.use_softmax:
@@ -203,7 +202,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
203202

204203

205204
def softmax_focal_loss(
206-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
205+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
207206
) -> torch.Tensor:
208207
"""
209208
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -225,7 +224,7 @@ def softmax_focal_loss(
225224

226225

227226
def sigmoid_focal_loss(
228-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
227+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
229228
) -> torch.Tensor:
230229
"""
231230
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

monai/losses/perceptual.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def __init__(
9595

9696
if network_type.lower() not in list(PercetualNetworkType):
9797
raise ValueError(
98-
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
99-
% ", ".join(PercetualNetworkType)
98+
"Unrecognised criterion entered for Adversarial Loss. Must be one in: {}".format(", ".join(PercetualNetworkType))
10099
)
101100

102101
if cache_dir:

monai/losses/spatial_mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import inspect
1515
import warnings
1616
from collections.abc import Callable
17-
from typing import Any, Optional
17+
from typing import Any
1818

1919
import torch
2020
from torch.nn.modules.loss import _Loss
@@ -47,7 +47,7 @@ def __init__(
4747
if not callable(self.loss):
4848
raise ValueError("The loss function is not callable.")
4949

50-
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
50+
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
5151
"""
5252
Args:
5353
input: the shape should be BNH[WD].

monai/losses/sure_loss.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import Callable, Optional
14+
from typing import Callable
1515

1616
import torch
1717
import torch.nn as nn
@@ -42,10 +42,10 @@ def sure_loss_function(
4242
operator: Callable,
4343
x: torch.Tensor,
4444
y_pseudo_gt: torch.Tensor,
45-
y_ref: Optional[torch.Tensor] = None,
46-
eps: Optional[float] = -1.0,
47-
perturb_noise: Optional[torch.Tensor] = None,
48-
complex_input: Optional[bool] = False,
45+
y_ref: torch.Tensor | None = None,
46+
eps: float | None = -1.0,
47+
perturb_noise: torch.Tensor | None = None,
48+
complex_input: bool | None = False,
4949
) -> torch.Tensor:
5050
"""
5151
Args:
@@ -131,7 +131,7 @@ class SURELoss(_Loss):
131131
(https://arxiv.org/pdf/2310.01799.pdf)
132132
"""
133133

134-
def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None:
134+
def __init__(self, perturb_noise: torch.Tensor | None = None, eps: float | None = None) -> None:
135135
"""
136136
Args:
137137
perturb_noise (torch.Tensor, optional): The noise vector of shape
@@ -149,8 +149,8 @@ def forward(
149149
operator: Callable,
150150
x: torch.Tensor,
151151
y_pseudo_gt: torch.Tensor,
152-
y_ref: Optional[torch.Tensor] = None,
153-
complex_input: Optional[bool] = False,
152+
y_ref: torch.Tensor | None = None,
153+
complex_input: bool | None = False,
154154
) -> torch.Tensor:
155155
"""
156156
Args:

monai/transforms/utility/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from collections.abc import Hashable, Mapping, Sequence
2222
from copy import deepcopy
2323
from functools import partial
24-
from typing import Any, Callable, Union
24+
from typing import Any, Callable
2525

2626
import numpy as np
2727
import torch
@@ -1216,7 +1216,7 @@ def __init__(self, name: str, *args, **kwargs) -> None:
12161216
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
12171217
self.trans = transform(*args, **kwargs)
12181218

1219-
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
1219+
def __call__(self, img: NdarrayOrTensor | Mapping[Hashable, NdarrayOrTensor]):
12201220
"""
12211221
Args:
12221222
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
@@ -1248,7 +1248,7 @@ def __init__(self, name: str, *args, **kwargs) -> None:
12481248
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
12491249
self.trans = transform(*args, **kwargs)
12501250

1251-
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
1251+
def __call__(self, img: NdarrayOrTensor | Mapping[Hashable, NdarrayOrTensor]):
12521252
"""
12531253
Args:
12541254
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,

0 commit comments

Comments
 (0)