|
52 | 52 |
|
53 | 53 | EXTS = ['jpg', 'jpeg', 'png'] |
54 | 54 | EPS = 1e-8 |
55 | | -EVALUATE_EVERY = 1000 |
56 | 55 | CALC_FID_NUM_IMAGES = 12800 |
57 | 56 |
|
58 | 57 | # helper classes |
@@ -687,7 +686,7 @@ def forward(self, x): |
687 | 686 | return x |
688 | 687 |
|
689 | 688 | class Trainer(): |
690 | | - def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], top_k_training = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs): |
| 689 | + def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, evaluate_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], top_k_training = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs): |
691 | 690 | self.GAN_params = [args, kwargs] |
692 | 691 | self.GAN = None |
693 | 692 |
|
@@ -719,6 +718,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity, |
719 | 718 | self.num_workers = num_workers |
720 | 719 | self.mixed_prob = mixed_prob |
721 | 720 |
|
| 721 | + self.evaluate_every = evaluate_every |
722 | 722 | self.save_every = save_every |
723 | 723 | self.steps = 0 |
724 | 724 |
|
@@ -977,8 +977,8 @@ def train(self): |
977 | 977 | if self.steps % self.save_every == 0: |
978 | 978 | self.save(self.checkpoint_num) |
979 | 979 |
|
980 | | - if self.steps % EVALUATE_EVERY == 0 or (self.steps % 100 == 0 and self.steps < 2500): |
981 | | - self.evaluate(floor(self.steps / EVALUATE_EVERY)) |
| 980 | + if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 2500): |
| 981 | + self.evaluate(floor(self.steps / self.evaluate_every)) |
982 | 982 |
|
983 | 983 | if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0: |
984 | 984 | num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size) |
|
0 commit comments