
Commit a3ba351

properly manage gradient accumulations for distributed case
1 parent d0bab13

File tree

2 files changed: 31 additions, 3 deletions

  setup.py
  stylegan2_pytorch/stylegan2_pytorch.py

setup.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
       'stylegan2_pytorch = stylegan2_pytorch.cli:main',
     ],
   },
-  version = '0.22.0',
+  version = '0.22.1',
   license='GPLv3+',
   description = 'StyleGan2 in Pytorch',
   author = 'Phil Wang',

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 30 additions & 2 deletions
@@ -3,12 +3,14 @@
 import math
 import fire
 import json
+
 from tqdm import tqdm
 from math import floor, log2
 from random import random
 from shutil import rmtree
 from functools import partial
 import multiprocessing
+from contextlib import contextmanager, ExitStack

 import numpy as np

@@ -107,6 +109,17 @@ def forward(self, x):

 # helpers

+@contextmanager
+def null_context():
+    yield
+
+def combine_contexts(contexts):
+    @contextmanager
+    def multi_contexts():
+        with ExitStack() as stack:
+            yield [stack.enter_context(ctx()) for ctx in contexts]
+    return multi_contexts
+
 def default(value, d):
     return d if value is None else value

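The two helpers added above compose an arbitrary collection of context managers into a single one. Below is a small standalone sketch (not part of the commit; the tag context manager and the printed values are purely illustrative) of how combine_contexts is meant to be called: each element of contexts is a zero-argument callable returning a context manager, and the returned multi_contexts enters all of them through one ExitStack.

from contextlib import contextmanager, ExitStack

# combine_contexts, copied from the hunk above
def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

# illustrative context manager that just logs when it is entered and exited
@contextmanager
def tag(name):
    print(f'enter {name}')
    yield name
    print(f'exit {name}')

both = combine_contexts([lambda: tag('a'), lambda: tag('b')])
with both() as values:
    print(values)   # ['a', 'b'] -- both contexts are active here
# the ExitStack unwinds in reverse order: 'exit b', then 'exit a'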
@@ -127,6 +140,19 @@ def raise_if_nan(t):
     if torch.isnan(t):
         raise NanException

+def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
+    if is_ddp:
+        num_no_syncs = gradient_accumulate_every - 1
+        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
+        tail = [null_context]
+        contexts = head + tail
+    else:
+        contexts = [null_context] * gradient_accumulate_every
+
+    for context in contexts:
+        with context():
+            yield
+
 def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     if fp16:
         with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
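gradient_accumulate_contexts is a generator: it yields once per accumulation step, and in the DDP case it schedules the first gradient_accumulate_every - 1 steps under the wrapped modules' no_sync() context, with a plain null_context for the final step; with is_ddp=False every step runs under null_context. Below is a self-contained sketch (not from the repo) of that schedule, assuming null_context, combine_contexts and gradient_accumulate_contexts are in scope exactly as defined above. FakeDDP is a stand-in for torch.nn.parallel.DistributedDataParallel: its no_sync() mimics the real method, which suppresses the gradient all-reduce for backward passes run inside it.

from contextlib import contextmanager

class FakeDDP:
    def __init__(self):
        self.sync_enabled = True

    # records whether gradient synchronization would currently be suppressed
    @contextmanager
    def no_sync(self):
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True

ddp = FakeDDP()
for step, _ in enumerate(gradient_accumulate_contexts(2, is_ddp=True, ddps=[ddp])):
    print(step, 'sync' if ddp.sync_enabled else 'no_sync')
# prints:
# 0 no_sync   <- a backward pass here would skip the all-reduce
# 1 sync      <- final accumulation step, gradients are synchronized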
@@ -244,6 +270,7 @@ def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
         self.folder = folder
         self.image_size = image_size
         self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+        assert len(self.paths) > 0, f'No images were found in {folder} for training'

         convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
         num_channels = 3 if not transparent else 4
@@ -802,7 +829,7 @@ def train(self):
             avg_pl_length = self.pl_mean
         self.GAN.D_opt.zero_grad()

-        for i in range(self.gradient_accumulate_every):
+        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug]):
             get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
             style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
             noise = image_noise(batch_size, image_size, device=self.rank)
@@ -849,7 +876,8 @@ def train(self):
         # train generator

         self.GAN.G_opt.zero_grad()
-        for i in range(self.gradient_accumulate_every):
+
+        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]):
             style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
             noise = image_noise(batch_size, image_size, device=self.rank)

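Taken together, the two train() changes above swap the plain range(...) accumulation loops for this generator, passing ddps=[D_aug] for the discriminator pass and ddps=[S, G, D_aug] for the generator pass. A condensed sketch of the resulting pattern, using hypothetical names (model_ddp, opt and next_batch stand in for the repo's actual Trainer attributes):

import torch.nn.functional as F
from stylegan2_pytorch.stylegan2_pytorch import gradient_accumulate_contexts

def train_step(model_ddp, opt, next_batch, gradient_accumulate_every, is_ddp):
    # one optimizer step built from several accumulated micro-batches
    opt.zero_grad()
    for _ in gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps=[model_ddp]):
        x, y = next_batch()
        loss = F.mse_loss(model_ddp(x), y)
        # earlier micro-batches are scheduled under no_sync(), so the DDP
        # gradient all-reduce is meant to fire only on the final backward
        loss.backward()
    opt.step()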