@@ -686,13 +686,54 @@ def forward(self, x):
686686 return x
687687
688688class Trainer ():
689- def __init__ (self , name , results_dir , models_dir , image_size , network_capacity , transparent = False , batch_size = 4 , mixed_prob = 0.9 , gradient_accumulate_every = 1 , lr = 2e-4 , lr_mlp = 1. , ttur_mult = 2 , rel_disc_loss = False , num_workers = None , save_every = 1000 , evaluate_every = 1000 , trunc_psi = 0.6 , fp16 = False , cl_reg = False , fq_layers = [], fq_dict_size = 256 , attn_layers = [], no_const = False , aug_prob = 0. , aug_types = ['translation' , 'cutout' ], top_k_training = False , generator_top_k_gamma = 0.99 , generator_top_k_frac = 0.5 , dataset_aug_prob = 0. , calculate_fid_every = None , is_ddp = False , rank = 0 , world_size = 1 , * args , ** kwargs ):
689+ def __init__ (
690+ self ,
691+ name = 'default' ,
692+ results_dir = 'results' ,
693+ models_dir = 'models' ,
694+ base_dir = './' ,
695+ image_size = 128 ,
696+ network_capacity = 16 ,
697+ transparent = False ,
698+ batch_size = 4 ,
699+ mixed_prob = 0.9 ,
700+ gradient_accumulate_every = 1 ,
701+ lr = 2e-4 ,
702+ lr_mlp = 1. ,
703+ ttur_mult = 2 ,
704+ rel_disc_loss = False ,
705+ num_workers = None ,
706+ save_every = 1000 ,
707+ evaluate_every = 1000 ,
708+ trunc_psi = 0.6 ,
709+ fp16 = False ,
710+ cl_reg = False ,
711+ fq_layers = [],
712+ fq_dict_size = 256 ,
713+ attn_layers = [],
714+ no_const = False ,
715+ aug_prob = 0. ,
716+ aug_types = ['translation' , 'cutout' ],
717+ top_k_training = False ,
718+ generator_top_k_gamma = 0.99 ,
719+ generator_top_k_frac = 0.5 ,
720+ dataset_aug_prob = 0. ,
721+ calculate_fid_every = None ,
722+ is_ddp = False ,
723+ rank = 0 ,
724+ world_size = 1 ,
725+ * args ,
726+ ** kwargs
727+ ):
690728 self .GAN_params = [args , kwargs ]
691729 self .GAN = None
692730
693731 self .name = name
694- self .results_dir = Path (results_dir )
695- self .models_dir = Path (models_dir )
732+
733+ base_dir = Path (base_dir )
734+ self .base_dir = base_dir
735+ self .results_dir = base_dir / results_dir
736+ self .models_dir = base_dir / models_dir
696737 self .config_path = self .models_dir / name / '.config.json'
697738
698739 assert log2 (image_size ).is_integer (), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
@@ -1076,23 +1117,34 @@ def calculate_fid(self, num_batches):
10761117 return fid_score .calculate_fid_given_paths ([real_path , fake_path ], 256 , True , 2048 )
10771118
10781119 @torch .no_grad ()
1079- def generate_truncated (self , S , G , style , noi , trunc_psi = 0.75 , num_image_tiles = 8 ):
1080- latent_dim = G .latent_dim
1120+ def truncate_style (self , tensor , trunc_psi = 0.75 ):
1121+ S = self .GAN .S
1122+ batch_size = self .batch_size
1123+ latent_dim = self .GAN .G .latent_dim
10811124
10821125 if not exists (self .av ):
10831126 z = noise (2000 , latent_dim , device = self .rank )
1084- samples = evaluate_in_chunks (self . batch_size , S , z ).cpu ().numpy ()
1127+ samples = evaluate_in_chunks (batch_size , S , z ).cpu ().numpy ()
10851128 self .av = np .mean (samples , axis = 0 )
10861129 self .av = np .expand_dims (self .av , axis = 0 )
1087-
1130+
1131+ av_torch = torch .from_numpy (self .av ).cuda (self .rank )
1132+ tensor = trunc_psi * (tensor - av_torch ) + av_torch
1133+ return tensor
1134+
1135+ @torch .no_grad ()
1136+ def truncate_style_defs (self , w , trunc_psi = 0.75 ):
10881137 w_space = []
1089- for tensor , num_layers in style :
1090- tmp = S (tensor )
1091- av_torch = torch .from_numpy (self .av ).cuda (self .rank )
1092- tmp = trunc_psi * (tmp - av_torch ) + av_torch
1093- w_space .append ((tmp , num_layers ))
1138+ for tensor , num_layers in w :
1139+ tensor = self .truncate_style (tensor , trunc_psi = trunc_psi )
1140+ w_space .append ((tensor , num_layers ))
1141+ return w_space
10941142
1095- w_styles = styles_def_to_tensor (w_space )
1143+ @torch .no_grad ()
1144+ def generate_truncated (self , S , G , style , noi , trunc_psi = 0.75 , num_image_tiles = 8 ):
1145+ w = map (lambda t : (S (t [0 ]), t [1 ]), style )
1146+ w_truncated = self .truncate_style_defs (w , trunc_psi = trunc_psi )
1147+ w_styles = styles_def_to_tensor (w_truncated )
10961148 generated_images = evaluate_in_chunks (self .batch_size , G , w_styles , noi )
10971149 return generated_images .clamp_ (0. , 1. )
10981150
@@ -1159,8 +1211,8 @@ def init_folders(self):
11591211 (self .models_dir / self .name ).mkdir (parents = True , exist_ok = True )
11601212
11611213 def clear (self ):
1162- rmtree (f'./models/ { self .name } ' , True )
1163- rmtree (f'./results/ { self .name } ' , True )
1214+ rmtree (str ( self . models_dir / self .name ) , True )
1215+ rmtree (str ( self . results_dir / self .name ) , True )
11641216 rmtree (str (self .config_path ), True )
11651217 self .init_folders ()
11661218
@@ -1202,3 +1254,27 @@ def load(self, num = -1):
12021254 raise e
12031255 if self .GAN .fp16 and 'amp' in load_data :
12041256 amp .load_state_dict (load_data ['amp' ])
1257+
1258+ class ModelLoader :
1259+ def __init__ (self , * , base_dir , name = 'default' , load_from = - 1 ):
1260+ self .model = Trainer (name = name , base_dir = base_dir )
1261+ self .model .load (load_from )
1262+
1263+ def noise_to_styles (self , noise , trunc_psi = None ):
1264+ w = self .model .GAN .S (noise )
1265+ if exists (trunc_psi ):
1266+ w = self .model .truncate_style (w )
1267+ return w
1268+
1269+ def styles_to_images (self , w ):
1270+ batch_size , * _ = w .shape
1271+ num_layers = self .model .GAN .G .num_layers
1272+ image_size = self .model .image_size
1273+ w_def = [(w , num_layers )]
1274+
1275+ w_tensors = styles_def_to_tensor (w_def )
1276+ noise = image_noise (batch_size , image_size , device = 0 )
1277+
1278+ images = self .model .GAN .G (w_tensors , noise )
1279+ images .clamp_ (0. , 1. )
1280+ return images
0 commit comments