
Commit a59e8e5

Merge pull request #140 from lucidrains/pw/add-generator-loss-top-k
add Top-k Generator Training
2 parents 94f7349 + 1301330 commit a59e8e5

File tree

3 files changed: +42, -2 lines changed

README.md

Lines changed: 20 additions & 0 deletions
@@ -260,6 +260,15 @@ By default, the StyleGAN architecture styles a constant learned 4x4 block as it
$ stylegan2_pytorch --data ./data --no-const
```

+## Top-k Training for Generator
+
+A new paper has produced evidence that simply zeroing out the gradient contributions from the samples the discriminator most confidently deems fake lets the generator learn significantly better, achieving a new state of the art.
+
+```
+$ stylegan2_pytorch --data ./data --generator-top-k --generator-top-k-frac 0.5 --generator-top-k-gamma 0.99
+```
+
+Gamma is a decay rate that slowly decreases k from the full batch size down to the target fraction (50% by default, also a modifiable hyperparameter).

## Appreciation

@@ -373,3 +382,14 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
    primaryClass = {cs.LG}
}
```
+
+```bibtex
+@misc{sinha2020topk,
+    title   = {Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
+    author  = {Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
+    year    = {2020},
+    eprint  = {2002.06224},
+    archivePrefix = {arXiv},
+    primaryClass = {stat.ML}
+}
+```
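
To make the decay schedule described in the README hunk above concrete, here is a small standalone sketch. It is not part of the commit: the batch size and epoch counts are invented for illustration, and only the defaults gamma = 0.99 and frac = 0.5 come from the diff.

```python
import math

# Hypothetical values chosen only to illustrate the schedule.
batch_size = 32
generator_top_k_gamma = 0.99
generator_top_k_frac = 0.5

for epochs in [0, 10, 25, 50, 100]:
    # the kept fraction decays geometrically with training epochs,
    # but never falls below the configured floor
    k_frac = max(generator_top_k_gamma ** epochs, generator_top_k_frac)
    k = math.ceil(batch_size * k_frac)
    print(f'after {epochs:3d} epochs: keep top {k} of {batch_size} generated samples')
```

With these defaults, k reaches the 50% floor after roughly 69 epochs, since 0.99^69 ≈ 0.5.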

stylegan2_pytorch/cli.py

Lines changed: 6 additions & 0 deletions
@@ -96,6 +96,9 @@ def train_from_folder(
    no_const = False,
    aug_prob = 0.,
    aug_types = ['translation', 'cutout'],
+    generator_top_k = False,
+    generator_top_k_gamma = 0.99,
+    generator_top_k_frac = 0.5,
    dataset_aug_prob = 0.,
    multi_gpus = False,
    calculate_fid_every = None
@@ -124,6 +127,9 @@ def train_from_folder(
        no_const = no_const,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
+        generator_top_k = generator_top_k,
+        generator_top_k_gamma = generator_top_k_gamma,
+        generator_top_k_frac = generator_top_k_frac,
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every
    )
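
These flags only plumb the new options from the command line through to the Trainer. As a hedged illustration, the same options could presumably also be set programmatically; the sketch below is an assumption rather than part of the commit, and it guesses that `train_from_folder` accepts a `data` path argument (not shown in this hunk) and is importable from `stylegan2_pytorch.cli`.

```python
# Hypothetical programmatic use of the new options; everything outside the three
# generator_top_k_* keywords is an assumption about the surrounding function.
from stylegan2_pytorch.cli import train_from_folder

train_from_folder(
    data = './data',                 # assumed: folder of training images
    generator_top_k = True,          # enable top-k training of the generator
    generator_top_k_gamma = 0.99,    # per-epoch decay of the kept fraction
    generator_top_k_frac = 0.5       # floor on the kept fraction of the batch
)
```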

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 16 additions & 2 deletions
@@ -685,7 +685,7 @@ def forward(self, x):
        return x

class Trainer():
-    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
+    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], generator_top_k = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None

@@ -747,6 +747,10 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,

        self.calculate_fid_every = calculate_fid_every

+        self.generator_top_k = generator_top_k
+        self.generator_top_k_gamma = generator_top_k_gamma
+        self.generator_top_k_frac = generator_top_k_frac
+
        assert not (is_ddp and cl_reg), 'Contrastive loss regularization does not work well with multi GPUs yet'
        self.is_ddp = is_ddp
        self.is_main = rank == 0
@@ -912,7 +916,17 @@ def train(self):

            generated_images = G(w_styles, noise)
            fake_output, _ = D_aug(generated_images, **aug_kwargs)
-            loss = fake_output.mean()
+            fake_output_loss = fake_output
+
+            if self.generator_top_k:
+                epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset)
+                k_frac = max(self.generator_top_k_gamma ** epochs, self.generator_top_k_frac)
+                k = math.ceil(batch_size * k_frac)
+
+                if k != batch_size:
+                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)
+
+            loss = fake_output_loss.mean()
            gen_loss = loss

            if apply_path_penalty:
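
The hunk in `train()` above is small enough to replay outside the Trainer. Below is a minimal standalone sketch of the added top-k generator loss, with a random tensor standing in for the discriminator output and invented values for the batch size, epoch count, and hyperparameters; it mirrors the added lines rather than reusing the library.

```python
import math
import torch

# Stand-ins for values the Trainer would supply (all invented for illustration).
batch_size = 8
generator_top_k = True
generator_top_k_gamma = 0.99
generator_top_k_frac = 0.5
epochs = 30.0  # in train() this is derived from steps, batch size and gradient accumulation

# In the real loop this would be D_aug(generated_images)[0]; here it is random scores.
fake_output = torch.randn(batch_size)
fake_output_loss = fake_output

if generator_top_k:
    # decay the kept fraction toward the floor, then round the kept count up
    k_frac = max(generator_top_k_gamma ** epochs, generator_top_k_frac)
    k = math.ceil(batch_size * k_frac)

    if k != batch_size:
        # keep the k smallest discriminator outputs: the generator minimizes
        # fake_output.mean(), so these are the samples it already does best on,
        # and the worst (batch_size - k) samples drop out of the gradient
        fake_output_loss, _ = fake_output_loss.topk(k = k, largest = False)

loss = fake_output_loss.mean()
print(f'kept {fake_output_loss.numel()} of {batch_size} samples, generator loss = {loss.item():.4f}')
```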
