Skip to content

Commit 1e4c307

Browse files
committed
offer way to turn off path length reg
1 parent 25ccfcf commit 1e4c307

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

stylegan2_pytorch/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def train_from_folder(
9595
trunc_psi = 0.75,
9696
mixed_prob = 0.9,
9797
fp16 = False,
98+
no_pl_reg = False,
9899
cl_reg = False,
99100
fq_layers = [],
100101
fq_dict_size = 256,
@@ -131,6 +132,7 @@ def train_from_folder(
131132
num_image_tiles = num_image_tiles,
132133
trunc_psi = trunc_psi,
133134
fp16 = fp16,
135+
no_pl_reg = no_pl_reg,
134136
cl_reg = cl_reg,
135137
fq_layers = fq_layers,
136138
fq_dict_size = fq_dict_size,

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ def __init__(
710710
trunc_psi = 0.6,
711711
fp16 = False,
712712
cl_reg = False,
713+
no_pl_reg = False,
713714
fq_layers = [],
714715
fq_dict_size = 256,
715716
attn_layers = [],
@@ -771,6 +772,7 @@ def __init__(
771772
self.av = None
772773
self.trunc_psi = trunc_psi
773774

775+
self.no_pl_reg = no_pl_reg
774776
self.pl_mean = None
775777

776778
self.gradient_accumulate_every = gradient_accumulate_every
@@ -880,7 +882,7 @@ def train(self):
880882
aug_kwargs = {'prob': aug_prob, 'types': aug_types}
881883

882884
apply_gradient_penalty = self.steps % 4 == 0
883-
apply_path_penalty = self.steps > 5000 and self.steps % 32 == 0
885+
apply_path_penalty = not self.no_pl_reg and self.steps > 5000 and self.steps % 32 == 0
884886
apply_cl_reg_to_generated = self.steps > 20000
885887

886888
S = self.GAN.S if not self.is_ddp else self.S_ddp

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.5'
1+
__version__ = '1.5.6'

0 commit comments

Comments
 (0)