
Commit 82684f9

add integration with Aim
1 parent e8a7dc8 commit 82684f9

5 files changed: +48 −3 lines changed

README.md

Lines changed: 19 additions & 0 deletions
````diff
@@ -261,6 +261,25 @@ images = loader.styles_to_images(styles) # call the generator on intermediate s
 save_image(images, './sample.jpg') # save your images, or do whatever you desire
 ```
 
+### Logging to experiment tracker
+
+To log the losses to an open source experiment tracker (Aim), you simply need to pass an extra flag like so.
+
+```bash
+$ stylegan2_pytorch --data ./data --log
+```
+
+Then, you need to make sure you have <a href="https://docs.docker.com/get-docker/">Docker installed</a>. Following the instructions at <a href="https://github.com/aimhubio/aim">Aim</a>, you execute the following in your terminal.
+
+```bash
+$ aim up
+```
+
+Then open up your browser to the address and you should see
+
+<img src="./images/aim.png" width="600px"></img>
+
 ## Experimental
 
 ### Top-k Training for Generator
````

images/aim.png

51.7 KB

setup.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
   download_url = 'https://github.com/lucidrains/stylegan2-pytorch/archive/v_036.tar.gz',
   keywords = ['generative adversarial networks', 'artificial intelligence'],
   install_requires=[
+    'aim',
     'contrastive_learner>=0.1.0',
     'fire',
     'kornia',
```

stylegan2_pytorch/cli.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -106,7 +106,8 @@ def train_from_folder(
     dataset_aug_prob = 0.,
     multi_gpus = False,
     calculate_fid_every = None,
-    seed = 42
+    seed = 42,
+    log = False
 ):
     model_args = dict(
         name = name,
@@ -138,7 +139,8 @@ def train_from_folder(
         generator_top_k_frac = generator_top_k_frac,
         dataset_aug_prob = dataset_aug_prob,
         calculate_fid_every = calculate_fid_every,
-        mixed_prob = mixed_prob
+        mixed_prob = mixed_prob,
+        log = log
     )
 
     if generate:
```
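
The two hunks above wire the new flag from the command line into the model arguments: `fire` turns `--log` into `log = True` for `train_from_folder`, which then lands in `model_args`. A minimal sketch of the resulting call, assuming `model_args` is forwarded to the `Trainer` constructor (that forwarding is not shown in this diff, and the kwargs below are an illustrative subset):

```python
# Hedged sketch of how --log propagates; not the repo's exact wiring.
from stylegan2_pytorch import Trainer

# `stylegan2_pytorch --data ./data --log` is roughly equivalent to
# train_from_folder(data = './data', log = True), whose kwargs are collected into model_args.
model_args = dict(
    name = 'default',   # illustrative subset of the real model_args
    log = True          # enables the aim.Session logger inside Trainer.__init__
)

trainer = Trainer(**model_args)  # assumes model_args is passed through to Trainer
```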

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -44,6 +44,8 @@
 except:
     APEX_AVAILABLE = False
 
+import aim
+
 assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
 
 num_cores = multiprocessing.cpu_count()
@@ -722,6 +724,7 @@ def __init__(
         is_ddp = False,
         rank = 0,
         world_size = 1,
+        log = False,
         *args,
         **kwargs
     ):
@@ -800,14 +803,20 @@ def __init__(
         self.rank = rank
         self.world_size = world_size
 
+        self.logger = aim.Session(experiment=name) if log else None
+
     @property
     def image_extension(self):
         return 'jpg' if not self.transparent else 'png'
 
     @property
     def checkpoint_num(self):
         return floor(self.steps // self.save_every)
-
+
+    @property
+    def hparams(self):
+        return {'image_size': self.image_size, 'network_capacity': self.network_capacity}
+
     def init_GAN(self):
         args, kwargs = self.GAN_params
         self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
@@ -819,6 +828,9 @@ def init_GAN(self):
             self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
             self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)
 
+        if exists(self.logger):
+            self.logger.set_params(self.hparams)
+
     def write_config(self):
         self.config_path.write_text(json.dumps(self.config()))
 
@@ -939,6 +951,7 @@ def train(self):
             if apply_gradient_penalty:
                 gp = gradient_penalty(image_batch, real_output)
                 self.last_gp_loss = gp.clone().detach().item()
+                self.track(self.last_gp_loss, 'GP')
                 disc_loss = disc_loss + gp
 
             disc_loss = disc_loss / self.gradient_accumulate_every
@@ -948,6 +961,8 @@ def train(self):
             total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every
 
         self.d_loss = float(total_disc_loss)
+        self.track(self.d_loss, 'D')
+
         self.GAN.D_opt.step()
 
         # train generator
@@ -992,12 +1007,15 @@ def train(self):
             total_gen_loss += loss.detach().item() / self.gradient_accumulate_every
 
         self.g_loss = float(total_gen_loss)
+        self.track(self.g_loss, 'G')
+
         self.GAN.G_opt.step()
 
         # calculate moving averages
 
         if apply_path_penalty and not np.isnan(avg_pl_length):
             self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
+            self.track(self.pl_mean, 'PL')
 
         if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
             self.GAN.EMA()
@@ -1203,6 +1221,11 @@ def print_log(self):
         log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
         print(log)
 
+    def track(self, value, name):
+        if not exists(self.logger):
+            return
+        self.logger.track(value, name = name)
+
     def model_name(self, num):
         return str(self.models_dir / self.name / f'model_{num}.pt')
```
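
Taken together, these changes follow the pre-1.x Aim `Session` pattern already visible in this diff: open a session per experiment, record hyperparameters once, then track named scalars during training. A standalone sketch of that pattern, using only the calls shown above (the experiment name and values are illustrative, not from the repo):

```python
import aim

# one session per experiment, mirroring `aim.Session(experiment=name)` in Trainer.__init__
session = aim.Session(experiment = 'stylegan2-demo')

# record hyperparameters once, as init_GAN does via `self.logger.set_params(self.hparams)`
session.set_params({'image_size': 128, 'network_capacity': 16})

# track named scalars over the course of training, as the new `track` helper does
for step in range(3):
    d_loss = 1.0 / (step + 1)   # placeholder value for illustration
    session.track(d_loss, name = 'D')
```

With a run recorded this way, `aim up` (as described in the README change) serves the dashboard that visualizes the tracked metrics.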
