Skip to content

Commit b440cc2

Browse files
committed
auto-set augmentation probability if number of training samples is low
1 parent c91e500 commit b440cc2

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,12 @@ def set_data_src(self, folder):
862862
dataloader = data.DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
863863
self.loader = cycle(dataloader)
864864

865+
# auto set augmentation prob for user if dataset is detected to be low
866+
num_samples = len(self.dataset)
867+
if not exists(self.aug_prob) and num_samples < 1e5:
868+
self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
869+
print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')
870+
865871
def train(self):
866872
assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
867873

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.9'
1+
__version__ = '1.5.10'

0 commit comments

Comments
 (0)