@@ -3,11 +3,13 @@
 import math
 import fire
 import json
+
 from tqdm import tqdm
 from math import floor, log2
 from random import random
 from shutil import rmtree
 from functools import partial
 import multiprocessing
+from contextlib import contextmanager, ExitStack
 
 import numpy as np
@@ -107,6 +109,17 @@ def forward(self, x):
 
 # helpers
 
+@contextmanager
+def null_context():
+    yield
+
+def combine_contexts(contexts):
+    @contextmanager
+    def multi_contexts():
+        with ExitStack() as stack:
+            yield [stack.enter_context(ctx()) for ctx in contexts]
+    return multi_contexts
+
 def default(value, d):
     return d if value is None else value
 
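For reference, `combine_contexts` turns a list of zero-argument context-manager factories into a single context manager that enters all of them through an `ExitStack` and unwinds them in reverse order on exit. A minimal sketch of the behavior, using a hypothetical `tag` manager purely to show the enter/exit ordering:

```python
from contextlib import contextmanager, ExitStack

def combine_contexts(contexts):
    # helper as added in the hunk above
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

@contextmanager
def tag(name):
    # hypothetical context manager, used only for illustration
    print(f'enter {name}')
    yield name
    print(f'exit {name}')

combined = combine_contexts([lambda: tag('a'), lambda: tag('b')])
with combined() as values:
    print(values)  # ['a', 'b']
# output order: enter a, enter b, ['a', 'b'], exit b, exit a
```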
@@ -127,6 +140,19 @@ def raise_if_nan(t):
     if torch.isnan(t):
         raise NanException
 
+def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
+    if is_ddp:
+        num_no_syncs = gradient_accumulate_every - 1
+        head = [combine_contexts(list(map(lambda ddp: ddp.no_sync, ddps)))] * num_no_syncs
+        tail = [null_context]
+        contexts = head + tail
+    else:
+        contexts = [null_context] * gradient_accumulate_every
+
+    for context in contexts:
+        with context():
+            yield
+
 def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     if fp16:
         with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
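`gradient_accumulate_contexts` is a generator that yields once per accumulation step, and the caller's loop body runs inside whichever context is active at that yield. Under DDP, the first `gradient_accumulate_every - 1` steps execute inside every module's `no_sync()` (the `no_sync` factories are materialized into a list so the combined context can be re-entered on each step), deferring the gradient all-reduce; the final step runs under `null_context`, so its backward pass performs a single synchronizing all-reduce. A rough sketch of the call pattern, assuming the helpers above are in scope and using a hypothetical stand-in for `DistributedDataParallel`:

```python
from contextlib import contextmanager

class FakeDDP:
    # hypothetical stand-in exposing only the no_sync() hook that
    # torch.nn.parallel.DistributedDataParallel provides
    @contextmanager
    def no_sync(self):
        print('no_sync entered: all-reduce deferred')
        yield
        print('no_sync exited')

step = 0
for _ in gradient_accumulate_contexts(3, True, ddps=[FakeDDP()]):
    step += 1
    print(f'accumulation step {step}')  # a real loop would run forward/backward here
# steps 1 and 2 run inside no_sync; step 3 runs in the null context,
# so its backward() would trigger the gradient synchronization
```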
@@ -244,6 +270,7 @@ def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
         self.folder = folder
         self.image_size = image_size
         self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+        assert len(self.paths) > 0, f'No images were found in {folder} for training'
 
         convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
         num_channels = 3 if not transparent else 4
@@ -802,7 +829,7 @@ def train(self):
             avg_pl_length = self.pl_mean
         self.GAN.D_opt.zero_grad()
 
-        for i in range(self.gradient_accumulate_every):
+        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug]):
             get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
             style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
             noise = image_noise(batch_size, image_size, device=self.rank)
@@ -849,7 +876,8 @@ def train(self):
         # train generator
 
         self.GAN.G_opt.zero_grad()
-        for i in range(self.gradient_accumulate_every):
+
+        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]):
             style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
             noise = image_noise(batch_size, image_size, device=self.rank)
 
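Net effect in `train()`: the iteration count is unchanged (when `is_ddp` is false every step runs under `null_context`, so the loop is equivalent to the old `range`), but under DDP each optimizer step now pays for one gradient all-reduce instead of `gradient_accumulate_every`. Passing `ddps=[S, G, D_aug]` on the generator step suppresses sync on every wrapped module the backward pass touches. Schematically, with a hypothetical `train_step` closure doing the forward and backward passes:

```python
# before: every backward() through a DDP-wrapped module all-reduces gradients
for i in range(self.gradient_accumulate_every):
    train_step()

# after: same number of iterations, but only the last one synchronizes
for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]):
    train_step()
```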