@@ -534,7 +534,7 @@ def forward(self, x):
534534 return x .squeeze (), quantize_loss
535535
536536class StyleGAN2 (nn .Module ):
537- def __init__ (self , image_size , latent_dim = 512 , fmap_max = 512 , style_depth = 8 , network_capacity = 16 , transparent = False , fp16 = False , cl_reg = False , steps = 1 , lr = 1e-4 , fq_layers = [], fq_dict_size = 256 , attn_layers = [], no_const = False ):
537+ def __init__ (self , image_size , latent_dim = 512 , fmap_max = 512 , style_depth = 8 , network_capacity = 16 , transparent = False , fp16 = False , cl_reg = False , steps = 1 , lr = 1e-4 , ttur_mult = 2 , fq_layers = [], fq_dict_size = 256 , attn_layers = [], no_const = False ):
538538 super ().__init__ ()
539539 self .lr = lr
540540 self .steps = steps
@@ -563,7 +563,7 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
563563
564564 generator_params = list (self .G .parameters ()) + list (self .S .parameters ())
565565 self .G_opt = AdamP (generator_params , lr = self .lr , betas = (0.5 , 0.9 ))
566- self .D_opt = AdamP (self .D .parameters (), lr = self .lr , betas = (0.5 , 0.9 ))
566+ self .D_opt = AdamP (self .D .parameters (), lr = self .lr * ttur_mult , betas = (0.5 , 0.9 ))
567567
568568 self ._init_weights ()
569569 self .reset_parameter_averaging ()
@@ -602,7 +602,7 @@ def forward(self, x):
602602 return x
603603
604604class Trainer ():
605- def __init__ (self , name , results_dir , models_dir , image_size , network_capacity , transparent = False , batch_size = 4 , mixed_prob = 0.9 , gradient_accumulate_every = 1 , lr = 2e-4 , num_workers = None , save_every = 1000 , trunc_psi = 0.6 , fp16 = False , cl_reg = False , fq_layers = [], fq_dict_size = 256 , attn_layers = [], no_const = False , aug_prob = 0. , dataset_aug_prob = 0. , * args , ** kwargs ):
605+ def __init__ (self , name , results_dir , models_dir , image_size , network_capacity , transparent = False , batch_size = 4 , mixed_prob = 0.9 , gradient_accumulate_every = 1 , lr = 2e-4 , ttur_mult = 2 , num_workers = None , save_every = 1000 , trunc_psi = 0.6 , fp16 = False , cl_reg = False , fq_layers = [], fq_dict_size = 256 , attn_layers = [], no_const = False , aug_prob = 0. , dataset_aug_prob = 0. , * args , ** kwargs ):
606606 self .GAN_params = [args , kwargs ]
607607 self .GAN = None
608608
@@ -623,6 +623,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
623623 self .aug_prob = aug_prob
624624
625625 self .lr = lr
626+ self .ttur_mult = ttur_mult
626627 self .batch_size = batch_size
627628 self .num_workers = num_workers
628629 self .mixed_prob = mixed_prob
@@ -656,7 +657,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
656657
657658 def init_GAN (self ):
658659 args , kwargs = self .GAN_params
659- self .GAN = StyleGAN2 (lr = self .lr , image_size = self .image_size , network_capacity = self .network_capacity , transparent = self .transparent , fq_layers = self .fq_layers , fq_dict_size = self .fq_dict_size , attn_layers = self .attn_layers , fp16 = self .fp16 , cl_reg = self .cl_reg , no_const = self .no_const , * args , ** kwargs )
660+ self .GAN = StyleGAN2 (lr = self .lr , ttur_mult = self . ttur_mult , image_size = self .image_size , network_capacity = self .network_capacity , transparent = self .transparent , fq_layers = self .fq_layers , fq_dict_size = self .fq_dict_size , attn_layers = self .attn_layers , fp16 = self .fp16 , cl_reg = self .cl_reg , no_const = self .no_const , * args , ** kwargs )
660661
661662 def write_config (self ):
662663 self .config_path .write_text (json .dumps (self .config ()))
0 commit comments