Skip to content

Commit 4dd3899

Browse files
committed
add --evaluate-every
1 parent c6b4b76 commit 4dd3899

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

stylegan2_pytorch/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def train_from_folder(
8484
rel_disc_loss = False,
8585
num_workers = None,
8686
save_every = 1000,
87+
evaluate_every = 1000,
8788
generate = False,
8889
generate_interpolation = False,
8990
interpolation_num_steps = 100,
@@ -122,6 +123,7 @@ def train_from_folder(
122123
rel_disc_loss = rel_disc_loss,
123124
num_workers = num_workers,
124125
save_every = save_every,
126+
evaluate_every = evaluate_every,
125127
trunc_psi = trunc_psi,
126128
fp16 = fp16,
127129
cl_reg = cl_reg,

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252

5353
EXTS = ['jpg', 'jpeg', 'png']
5454
EPS = 1e-8
55-
EVALUATE_EVERY = 1000
5655
CALC_FID_NUM_IMAGES = 12800
5756

5857
# helper classes
@@ -687,7 +686,7 @@ def forward(self, x):
687686
return x
688687

689688
class Trainer():
690-
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], top_k_training = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
689+
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, evaluate_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], top_k_training = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
691690
self.GAN_params = [args, kwargs]
692691
self.GAN = None
693692

@@ -719,6 +718,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
719718
self.num_workers = num_workers
720719
self.mixed_prob = mixed_prob
721720

721+
self.evaluate_every = evaluate_every
722722
self.save_every = save_every
723723
self.steps = 0
724724

@@ -977,8 +977,8 @@ def train(self):
977977
if self.steps % self.save_every == 0:
978978
self.save(self.checkpoint_num)
979979

980-
if self.steps % EVALUATE_EVERY == 0 or (self.steps % 100 == 0 and self.steps < 2500):
981-
self.evaluate(floor(self.steps / EVALUATE_EVERY))
980+
if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 2500):
981+
self.evaluate(floor(self.steps / self.evaluate_every))
982982

983983
if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
984984
num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.5'
1+
__version__ = '1.2.6'

0 commit comments

Comments
 (0)