Skip to content

Commit 62a83e1

Browse files
committed
allow user to set number of interpolation steps
1 parent b35810a commit 62a83e1

File tree

4 files changed

+8
-7
lines changed

4 files changed

+8
-7
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ $ stylegan2_pytorch --generate
7676
To generate a video of a interpolation through two random points in latent space.
7777

7878
```bash
79-
$ stylegan2_pytorch --generate-interpolation
79+
$ stylegan2_pytorch --generate-interpolation --interpolation-num-steps 100
8080
```
8181

8282
To save each individual frame of the interpolation

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
'stylegan2_pytorch = stylegan2_pytorch.cli:main',
99
],
1010
},
11-
version = '1.0.0',
11+
version = '1.0.1',
1212
license='GPLv3+',
1313
description = 'StyleGan2 in Pytorch',
1414
author = 'Phil Wang',

stylegan2_pytorch/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def train_from_folder(
8282
save_every = 1000,
8383
generate = False,
8484
generate_interpolation = False,
85+
interpolation_num_steps = 100,
8586
save_frames = False,
8687
num_image_tiles = 8,
8788
trunc_psi = 0.75,
@@ -135,7 +136,7 @@ def train_from_folder(
135136
model = Trainer(**model_args)
136137
model.load(load_from)
137138
samples_name = timestamped_filename()
138-
model.generate_interpolation(samples_name, num_image_tiles, save_frames = save_frames)
139+
model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
139140
print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
140141
return
141142

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles
998998
return generated_images.clamp_(0., 1.)
999999

10001000
@torch.no_grad()
1001-
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save_frames = False):
1001+
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_steps = 100, save_frames = False):
10021002
self.GAN.eval()
10031003
ext = 'jpg' if not self.transparent else 'png'
10041004
num_rows = num_image_tiles
@@ -1009,11 +1009,11 @@ def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save
10091009

10101010
# latents and noise
10111011

1012-
latents_low = noise(num_rows ** 2, latent_dim)
1013-
latents_high = noise(num_rows ** 2, latent_dim)
1012+
latents_low = noise(num_rows ** 2, latent_dim, device=self.rank)
1013+
latents_high = noise(num_rows ** 2, latent_dim, device=self.rank)
10141014
n = image_noise(num_rows ** 2, image_size, device=self.rank)
10151015

1016-
ratios = torch.linspace(0., 8., 100)
1016+
ratios = torch.linspace(0., 8., num_steps)
10171017

10181018
frames = []
10191019
for ratio in tqdm(ratios):

0 commit comments

Comments
 (0)