Skip to content

Commit e225255

Browse files
committed
allow for setting in max feature maps, after learning @aydao raised his to 1024 for superb results
1 parent 20c2d68 commit e225255

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

stylegan2_pytorch/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def train_from_folder(
7474
load_from = -1,
7575
image_size = 128,
7676
network_capacity = 16,
77+
fmap_max = 512,
7778
transparent = False,
7879
batch_size = 5,
7980
gradient_accumulate_every = 6,
@@ -117,6 +118,7 @@ def train_from_folder(
117118
gradient_accumulate_every = gradient_accumulate_every,
118119
image_size = image_size,
119120
network_capacity = network_capacity,
121+
fmap_max = fmap_max,
120122
transparent = transparent,
121123
lr = learning_rate,
122124
lr_mlp = lr_mlp,

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ def __init__(
696696
base_dir = './',
697697
image_size = 128,
698698
network_capacity = 16,
699+
fmap_max = 512,
699700
transparent = False,
700701
batch_size = 4,
701702
mixed_prob = 0.9,
@@ -742,6 +743,7 @@ def __init__(
742743
assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
743744
self.image_size = image_size
744745
self.network_capacity = network_capacity
746+
self.fmap_max = fmap_max
745747
self.transparent = transparent
746748

747749
self.fq_layers = cast_list(fq_layers)
@@ -819,7 +821,7 @@ def hparams(self):
819821

820822
def init_GAN(self):
821823
args, kwargs = self.GAN_params
822-
self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
824+
self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
823825

824826
if self.is_ddp:
825827
ddp_kwargs = {'device_ids': [self.rank]}
@@ -841,6 +843,7 @@ def load_config(self):
841843
self.transparent = config['transparent']
842844
self.fq_layers = config['fq_layers']
843845
self.fq_dict_size = config['fq_dict_size']
846+
self.fmap_max = config.pop('fmap_max', 512)
844847
self.attn_layers = config.pop('attn_layers', [])
845848
self.no_const = config.pop('no_const', False)
846849
del self.GAN

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.0'
1+
__version__ = '1.5.1'

0 commit comments

Comments
 (0)