Commit 59bc8d4
add image augmentation feature for low data settings
1 parent a88681e commit 59bc8d4

File tree

- README.md
- bin/stylegan2_pytorch
- setup.py
- stylegan2_pytorch/stylegan2_pytorch.py

4 files changed (+78 -10 lines changed)


README.md

Lines changed: 27 additions & 1 deletion
````diff
@@ -79,7 +79,22 @@ If a previous checkpoint contained a better generator, (which often happens as g
 $ stylegan2_pytorch --generate --load-from {checkpoint number}
 ```
 
-### Attention
+## Low amounts of Training Data
+
+In the past, GANs needed a lot of data to learn how to generate well. The faces model, for example, was trained on **70k** high-quality images from Flickr.
+
+However, in May 2020, researchers around the world independently converged on a simple technique to reduce that number to as low as **1-2k**: differentiably augment all images, real or generated, on their way into the discriminator during training.
+
+If the augmentation probability is kept low enough, the augmentations will not 'leak' into the generations.
+
+In the low-data setting, you can enable the feature with a simple flag.
+
+```bash
+# find a suitable probability between 0. and 0.7 at maximum
+$ stylegan2_pytorch --data ./data --aug-prob 0.25
+```
+
+## Attention
 
 This framework also allows for you to add an efficient form of self-attention to the designated layers of the discriminator (and the symmetric layer of the generator), which will greatly improve results. The more attention you can afford, the better!
 
````
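For intuition, here is a minimal sketch of the idea behind the new flag, separate from this commit's actual implementation (the `AugWrapper` in the `stylegan2_pytorch/stylegan2_pytorch.py` diff below): gate a differentiable augmentation behind a probability `p` and apply it to every batch, real or generated, on its way into the discriminator. The names `center_zoom` and `discriminator_input` are illustrative only, not part of the library.

```python
from random import random

import torch.nn.functional as F

def center_zoom(images):
    # one example of a differentiable augmentation: keep the central
    # three-quarters of each image, then resize back to full resolution
    _, _, h, w = images.shape
    crop = images[:, :, h // 8 : h - h // 8, w // 8 : w - w // 8]
    return F.interpolate(crop, size = (h, w))

def discriminator_input(images, p):
    # real and generated batches pass through the same gate, so the
    # discriminator cannot tell them apart by augmentation statistics alone
    return center_zoom(images) if random() < p else images
```

Because the augmentation is differentiable, generator gradients flow through it unchanged; keeping `p` well below 1 is what prevents the augmentations from leaking into the generated distribution.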
````diff
@@ -277,4 +292,15 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
     eprint = {2006.02595},
     archivePrefix = {arXiv}
 }
+```
+
+```bibtex
+@misc{karras2020training,
+    title = {Training Generative Adversarial Networks with Limited Data},
+    author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
+    year = {2020},
+    eprint = {2006.06676},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
 ```
````

bin/stylegan2_pytorch

Lines changed: 4 additions & 2 deletions
```diff
@@ -30,7 +30,8 @@ def train_from_folder(
     fq_layers = [],
     fq_dict_size = 256,
     attn_layers = [],
-    no_const = False
+    no_const = False,
+    aug_prob = 0.
 ):
     model = Trainer(
         name,
@@ -50,7 +51,8 @@ def train_from_folder(
         fq_layers = fq_layers,
         fq_dict_size = fq_dict_size,
         attn_layers = attn_layers,
-        no_const = no_const
+        no_const = no_const,
+        aug_prob = aug_prob
     )
 
     if not new:
```
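Assuming the script exposes `train_from_folder` through an automatic CLI generator, so that hyphenated flags map onto keyword arguments (as the README hunk above suggests), the new keyword is usable both ways; a sketch:

```python
# from the command line:
#   $ stylegan2_pytorch --data ./data --aug-prob 0.25
# or calling the entry point directly:
train_from_folder(data = './data', aug_prob = 0.25)
```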

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
   name = 'stylegan2_pytorch',
   packages = find_packages(),
   scripts=['bin/stylegan2_pytorch'],
-  version = '0.15.0',
+  version = '0.16.0',
   license='GPLv3+',
   description = 'StyleGan2 in Pytorch',
   author = 'Phil Wang',
```

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 46 additions & 6 deletions
```diff
@@ -213,7 +213,6 @@ def __init__(self, folder, image_size, transparent = False):
         self.transform = transforms.Compose([
             transforms.Lambda(convert_image_fn),
             transforms.Lambda(partial(resize_to_minimum_size, image_size)),
-            transforms.RandomHorizontalFlip(),
             transforms.Resize(image_size),
             transforms.CenterCrop(image_size),
             transforms.ToTensor(),
@@ -228,6 +227,41 @@ def __getitem__(self, index):
         img = Image.open(path)
         return self.transform(img)
 
+# augmentations
+
+def random_float(lo, hi):
+    return lo + (hi - lo) * random()
+
+def random_crop_and_resize(tensor, scale):
+    b, c, h, _ = tensor.shape
+    new_width = int(h * scale)
+    delta = h - new_width
+    h_delta = int(random() * delta)
+    w_delta = int(random() * delta)
+    cropped = tensor[:, :, h_delta:(h_delta + new_width), w_delta:(w_delta + new_width)].clone()
+    return F.interpolate(cropped, size=(h, h))
+
+def random_hflip(tensor, prob):
+    if prob > random():
+        return tensor
+    return torch.flip(tensor, dims=(3,))
+
+class AugWrapper(nn.Module):
+    def __init__(self, D, image_size):
+        super().__init__()
+        self.D = D
+
+    def forward(self, images, prob = 0., detach = False):
+        if random() < prob:
+            random_scale = random_float(0.5, 0.9)
+            images = random_hflip(images, prob=0.5)
+            images = random_crop_and_resize(images, scale = random_scale)
+
+        if detach:
+            images.detach_()
+
+        return self.D(images)
+
 # stylegan2 classes
 
 class StyleVectorizer(nn.Module):
```
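A quick sanity check of the helpers above, as a sketch that assumes they are imported from `stylegan2_pytorch/stylegan2_pytorch.py`. One subtlety worth noting: in `random_hflip`, `prob` is the probability of *keeping* the original orientation, which is symmetric at the `0.5` that `AugWrapper` passes.

```python
import torch
import torch.nn as nn

images = torch.randn(4, 3, 64, 64)  # dummy batch of four 64x64 RGB images

# scale 0.7 selects a random 44x44 window, then resizes it back to 64x64
zoomed = random_crop_and_resize(images, scale = 0.7)
assert zoomed.shape == images.shape

# flips the entire batch along the width axis, or leaves it untouched
flipped = random_hflip(images, prob = 0.5)

# nn.Identity() stands in for the discriminator in this sketch
D_aug = AugWrapper(nn.Identity(), image_size = 64)
out = D_aug(images, prob = 0.25)  # roughly a quarter of calls see an augmented batch
assert out.shape == images.shape
```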
```diff
@@ -495,7 +529,10 @@ def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capaci
         self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)
 
         # experimental contrastive loss discriminator regularization
-        self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None
+        self.D_cl = None
+
+        # wrapper for augmenting all images going into the discriminator
+        self.D_aug = AugWrapper(self.D, image_size)
 
         set_requires_grad(self.SE, False)
         set_requires_grad(self.GE, False)
@@ -540,7 +577,7 @@ def forward(self, x):
         return x
 
 class Trainer():
-    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, *args, **kwargs):
+    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., *args, **kwargs):
         self.GAN_params = [args, kwargs]
         self.GAN = None
 
@@ -558,6 +595,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
 
         self.attn_layers = cast_list(attn_layers)
         self.no_const = no_const
+        self.aug_prob = aug_prob
 
         self.lr = lr
         self.batch_size = batch_size
@@ -632,6 +670,8 @@ def train(self):
         latent_dim = self.GAN.G.latent_dim
         num_layers = self.GAN.G.num_layers
 
+        aug_prob = self.aug_prob
+
         apply_gradient_penalty = self.steps % 4 == 0
         apply_path_penalty = self.steps % 32 == 0
         apply_cl_reg_to_generated = self.steps > 20000
@@ -677,11 +717,11 @@
             w_styles = styles_def_to_tensor(w_space)
 
             generated_images = self.GAN.G(w_styles, noise)
-            fake_output, fake_q_loss = self.GAN.D(generated_images.clone().detach())
+            fake_output, fake_q_loss = self.GAN.D_aug(generated_images.clone().detach(), detach = True, prob = aug_prob)
 
             image_batch = next(self.loader).cuda()
             image_batch.requires_grad_()
-            real_output, real_q_loss = self.GAN.D(image_batch)
+            real_output, real_q_loss = self.GAN.D_aug(image_batch, prob = aug_prob)
 
             divergence = (F.relu(1 + real_output) + F.relu(1 - fake_output)).mean()
             disc_loss = divergence
```
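An aside on the objective these outputs feed: the `divergence` above is a hinge loss, written with signs flipped relative to the more common convention (the discriminator drives real scores below -1 and fake scores above +1), and the generator loss in the following hunk simply minimizes the fake score. The augmentation changes only what the discriminator sees, not the loss itself. A sketch of both terms:

```python
import torch.nn.functional as F

def disc_hinge_loss(real_output, fake_output):
    # same expression as the divergence above:
    # real scores pushed <= -1, fake scores pushed >= +1
    return (F.relu(1 + real_output) + F.relu(1 - fake_output)).mean()

def gen_hinge_loss(fake_output):
    # the generator minimizes the fake score (see the following hunk)
    return fake_output.mean()
```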
```diff
@@ -716,7 +756,7 @@
             w_styles = styles_def_to_tensor(w_space)
 
             generated_images = self.GAN.G(w_styles, noise)
-            fake_output, _ = self.GAN.D(generated_images)
+            fake_output, _ = self.GAN.D_aug(generated_images, prob = aug_prob)
             loss = fake_output.mean()
             gen_loss = loss
 
```