@@ -696,6 +696,7 @@ def __init__(
         base_dir = './',
         image_size = 128,
         network_capacity = 16,
+        fmap_max = 512,
         transparent = False,
         batch_size = 4,
         mixed_prob = 0.9,
@@ -742,6 +743,7 @@ def __init__(
         assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
         self.image_size = image_size
         self.network_capacity = network_capacity
+        self.fmap_max = fmap_max
         self.transparent = transparent

         self.fq_layers = cast_list(fq_layers)
@@ -819,7 +821,7 @@ def hparams(self):

     def init_GAN(self):
         args, kwargs = self.GAN_params
-        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
+        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)

         if self.is_ddp:
             ddp_kwargs = {'device_ids': [self.rank]}
@@ -841,6 +843,7 @@ def load_config(self):
         self.transparent = config['transparent']
         self.fq_layers = config['fq_layers']
         self.fq_dict_size = config['fq_dict_size']
+        self.fmap_max = config.pop('fmap_max', 512)
         self.attn_layers = config.pop('attn_layers', [])
         self.no_const = config.pop('no_const', False)
         del self.GAN