Skip to content

Commit 723c584

Browse files
committed
🐛 fix cosine noise scheduler
Signed-off-by: Slava Shen <shen9910@gmail.com>
1 parent b58e883 commit 723c584

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

monai/networks/schedulers/scheduler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
105105
x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
106106
alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
107107
alphas_cumprod /= alphas_cumprod[0].item()
108-
alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
109-
betas = 1.0 - alphas
110-
return betas, alphas, alphas_cumprod[:-1]
108+
betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
109+
betas = torch.clip(betas, 0.0, 0.999)
110+
alphas = 1.0 - betas
111+
alphas_cumprod = torch.cumprod(alphas, dim=0)
112+
return betas, alphas, alphas_cumprod
111113

112114

113115
class Scheduler(nn.Module):

monai/utils/jupyter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def plot_engine_status(
234234

235235

236236
def _get_loss_from_output(
237-
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
237+
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
238238
) -> torch.Tensor:
239239
"""Returns a single value from the network output, which is a dict or tensor."""
240240

0 commit comments

Comments
 (0)