Skip to content

Commit 1891844

Browse files
committed
another fix for training under amp
1 parent 36052f1 commit 1891844

8 files changed

+21
-7
lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
12+
from torch.cuda.amp import autocast
1213

1314
from einops import rearrange, reduce, repeat
1415
from einops.layers.torch import Rearrange
@@ -711,6 +712,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
711712

712713
return img
713714

715+
@autocast(enabled = False)
714716
def q_sample(self, x_start, t, noise=None):
715717
noise = default(noise, lambda: torch.randn_like(x_start))
716718

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch import sqrt
44
from torch import nn, einsum
55
import torch.nn.functional as F
6+
from torch.cuda.amp import autocast
67
from torch.special import expm1
78

89
from tqdm import tqdm
@@ -233,6 +234,7 @@ def sample(self, batch_size = 16):
233234

234235
# training related functions - noise prediction
235236

237+
@autocast(enabled = False)
236238
def q_sample(self, x_start, times, noise = None):
237239
noise = default(noise, lambda: torch.randn_like(x_start))
238240

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
706706

707707
return img
708708

709+
@autocast(enabled = False)
709710
def q_sample(self, x_start, t, noise = None):
710711
noise = default(noise, lambda: torch.randn_like(x_start))
711712

@@ -823,23 +824,22 @@ def __init__(
823824
num_samples = 25,
824825
results_folder = './results',
825826
amp = False,
826-
fp16 = False,
827+
mixed_precision_type = 'fp16',
827828
split_batches = True,
828829
convert_image_to = None,
829830
calculate_fid = True,
830-
inception_block_idx = 2048
831+
inception_block_idx = 2048,
832+
max_grad_norm = 1.
831833
):
832834
super().__init__()
833835

834836
# accelerator
835837

836838
self.accelerator = Accelerator(
837839
split_batches = split_batches,
838-
mixed_precision = 'fp16' if fp16 else 'no'
840+
mixed_precision = mixed_precision_type if amp else 'no'
839841
)
840842

841-
self.accelerator.native_amp = amp
842-
843843
# model
844844

845845
self.model = diffusion_model
@@ -867,6 +867,8 @@ def __init__(
867867
self.train_num_steps = train_num_steps
868868
self.image_size = diffusion_model.image_size
869869

870+
self.max_grad_norm = max_grad_norm
871+
870872
# dataset and dataloader
871873

872874
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
@@ -980,7 +982,7 @@ def train(self):
980982

981983
self.accelerator.backward(loss)
982984

983-
accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
985+
accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
984986
pbar.set_description(f'loss: {total_loss:.4f}')
985987

986988
accelerator.wait_for_everyone()

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ema_pytorch import EMA
1111
from torch import nn, einsum
1212
import torch.nn.functional as F
13+
from torch.cuda.amp import autocast
1314

1415
from einops import rearrange, reduce
1516
from einops.layers.torch import Rearrange
@@ -639,6 +640,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
639640

640641
return img
641642

643+
@autocast(enabled = False)
642644
def q_sample(self, x_start, t, noise=None):
643645
noise = default(noise, lambda: torch.randn_like(x_start))
644646

denoising_diffusion_pytorch/guided_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
12+
from torch.cuda.amp import autocast
1213
from torch.utils.data import Dataset, DataLoader
1314

1415
from torch.optim import Adam
@@ -708,6 +709,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
708709

709710
return img
710711

712+
@autocast(enabled = False)
711713
def q_sample(self, x_start, t, noise=None):
712714
noise = default(noise, lambda: torch.randn_like(x_start))
713715

denoising_diffusion_pytorch/simple_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import nn, einsum
77
import torch.nn.functional as F
88
from torch.special import expm1
9+
from torch.cuda.amp import autocast
910

1011
from tqdm import tqdm
1112
from einops import rearrange, repeat, reduce, pack, unpack
@@ -653,6 +654,7 @@ def sample(self, batch_size = 16):
653654

654655
# training related functions - noise prediction
655656

657+
@autocast(enabled = False)
656658
def q_sample(self, x_start, times, noise = None):
657659
noise = default(noise, lambda: torch.randn_like(x_start))
658660

denoising_diffusion_pytorch/v_param_continuous_time_gaussian_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch import nn, einsum
55
import torch.nn.functional as F
66
from torch.special import expm1
7+
from torch.cuda.amp import autocast
78

89
from tqdm import tqdm
910
from einops import rearrange, repeat, reduce
@@ -149,6 +150,7 @@ def sample(self, batch_size = 16):
149150

150151
# training related functions - noise prediction
151152

153+
@autocast(enabled = False)
152154
def q_sample(self, x_start, times, noise = None):
153155
noise = default(noise, lambda: torch.randn_like(x_start))
154156

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.7.4'
1+
__version__ = '1.7.6'

0 commit comments

Comments
 (0)