Skip to content

Commit b1365b8

Browse files
committed
add two time-scale update rule (TTUR)
1 parent 7663395 commit b1365b8

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

bin/stylegan2_pytorch

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def train_from_folder(
1919
gradient_accumulate_every = 5,
2020
num_train_steps = 150000,
2121
learning_rate = 2e-4,
22+
ttur_mult = 2,
2223
num_workers = None,
2324
save_every = 1000,
2425
generate = False,
@@ -45,6 +46,7 @@ def train_from_folder(
4546
network_capacity = network_capacity,
4647
transparent = transparent,
4748
lr = learning_rate,
49+
ttur_mult = ttur_mult,
4850
num_workers = num_workers,
4951
save_every = save_every,
5052
trunc_psi = trunc_psi,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'stylegan2_pytorch',
55
packages = find_packages(),
66
scripts=['bin/stylegan2_pytorch'],
7-
version = '0.17.14',
7+
version = '0.18.0',
88
license='GPLv3+',
99
description = 'StyleGan2 in Pytorch',
1010
author = 'Phil Wang',

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def forward(self, x):
534534
return x.squeeze(), quantize_loss
535535

536536
class StyleGAN2(nn.Module):
537-
def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False):
537+
def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False):
538538
super().__init__()
539539
self.lr = lr
540540
self.steps = steps
@@ -563,7 +563,7 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
563563

564564
generator_params = list(self.G.parameters()) + list(self.S.parameters())
565565
self.G_opt = AdamP(generator_params, lr = self.lr, betas=(0.5, 0.9))
566-
self.D_opt = AdamP(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))
566+
self.D_opt = AdamP(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))
567567

568568
self._init_weights()
569569
self.reset_parameter_averaging()
@@ -602,7 +602,7 @@ def forward(self, x):
602602
return x
603603

604604
class Trainer():
605-
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., dataset_aug_prob = 0., *args, **kwargs):
605+
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, ttur_mult = 2, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., dataset_aug_prob = 0., *args, **kwargs):
606606
self.GAN_params = [args, kwargs]
607607
self.GAN = None
608608

@@ -623,6 +623,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
623623
self.aug_prob = aug_prob
624624

625625
self.lr = lr
626+
self.ttur_mult = ttur_mult
626627
self.batch_size = batch_size
627628
self.num_workers = num_workers
628629
self.mixed_prob = mixed_prob
@@ -656,7 +657,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
656657

657658
def init_GAN(self):
658659
args, kwargs = self.GAN_params
659-
self.GAN = StyleGAN2(lr=self.lr, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, *args, **kwargs)
660+
self.GAN = StyleGAN2(lr = self.lr, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, *args, **kwargs)
660661

661662
def write_config(self):
662663
self.config_path.write_text(json.dumps(self.config()))

0 commit comments

Comments
 (0)