
Commit 9ade90b

add ModelLoader class to provide a simple interface for sampling images through Python code
1 parent ae7918f commit 9ade90b

File tree

4 files changed: +114 -17 lines

README.md

Lines changed: 21 additions & 0 deletions
@@ -288,6 +288,27 @@ By default, the StyleGAN architecture styles a constant learned 4x4 block as it
 $ stylegan2_pytorch --data ./data --no-const
 ```
 
+## Research
+
+If you would like to sample images programmatically, you can do so with the following simple `ModelLoader` class.
+
+```python
+import torch
+from torchvision.utils import save_image
+from stylegan2_pytorch import ModelLoader
+
+loader = ModelLoader(
+    base_dir = '/path/to/directory',   # path to where you invoked the command line tool
+    name = 'default'                   # the project name, defaults to 'default'
+)
+
+noise = torch.randn(1, 512).cuda() # noise
+styles = loader.noise_to_styles(noise, trunc_psi = 0.7)  # pass through mapping network
+images = loader.styles_to_images(styles) # call the generator on intermediate style vectors
+
+save_image(images, './sample.jpg') # save your images, or do whatever you desire
+```
+
 ## Alternatives
 
 <a href="https://github.com/lucidrains/unet-stylegan2">Stylegan2 + Unet Discriminator</a>

stylegan2_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from stylegan2_pytorch.stylegan2_pytorch import Trainer, StyleGAN2, NanException
+from stylegan2_pytorch.stylegan2_pytorch import Trainer, StyleGAN2, NanException, ModelLoader

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 91 additions & 15 deletions
@@ -686,13 +686,54 @@ def forward(self, x):
         return x
 
 class Trainer():
-    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, evaluate_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], top_k_training = False, generator_top_k_gamma = 0.99, generator_top_k_frac = 0.5, dataset_aug_prob = 0., calculate_fid_every = None, is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
+    def __init__(
+        self,
+        name = 'default',
+        results_dir = 'results',
+        models_dir = 'models',
+        base_dir = './',
+        image_size = 128,
+        network_capacity = 16,
+        transparent = False,
+        batch_size = 4,
+        mixed_prob = 0.9,
+        gradient_accumulate_every=1,
+        lr = 2e-4,
+        lr_mlp = 1.,
+        ttur_mult = 2,
+        rel_disc_loss = False,
+        num_workers = None,
+        save_every = 1000,
+        evaluate_every = 1000,
+        trunc_psi = 0.6,
+        fp16 = False,
+        cl_reg = False,
+        fq_layers = [],
+        fq_dict_size = 256,
+        attn_layers = [],
+        no_const = False,
+        aug_prob = 0.,
+        aug_types = ['translation', 'cutout'],
+        top_k_training = False,
+        generator_top_k_gamma = 0.99,
+        generator_top_k_frac = 0.5,
+        dataset_aug_prob = 0.,
+        calculate_fid_every = None,
+        is_ddp = False,
+        rank = 0,
+        world_size = 1,
+        *args,
+        **kwargs
+    ):
         self.GAN_params = [args, kwargs]
         self.GAN = None
 
         self.name = name
-        self.results_dir = Path(results_dir)
-        self.models_dir = Path(models_dir)
+
+        base_dir = Path(base_dir)
+        self.base_dir = base_dir
+        self.results_dir = base_dir / results_dir
+        self.models_dir = base_dir / models_dir
         self.config_path = self.models_dir / name / '.config.json'
 
         assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
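To make the new path handling concrete, here is a minimal sketch of how `base_dir` composes with the `results_dir` and `models_dir` defaults introduced above; the example directory is an assumption, not taken from the commit.

```python
from pathlib import Path

# mirrors the composition done in Trainer.__init__ with default arguments
base_dir = Path('/path/to/directory')
results_dir = base_dir / 'results'                       # /path/to/directory/results
models_dir = base_dir / 'models'                         # /path/to/directory/models
config_path = models_dir / 'default' / '.config.json'    # /path/to/directory/models/default/.config.json

print(results_dir, models_dir, config_path)
```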
@@ -1076,23 +1117,34 @@ def calculate_fid(self, num_batches):
         return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, True, 2048)
 
     @torch.no_grad()
-    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
-        latent_dim = G.latent_dim
+    def truncate_style(self, tensor, trunc_psi = 0.75):
+        S = self.GAN.S
+        batch_size = self.batch_size
+        latent_dim = self.GAN.G.latent_dim
 
         if not exists(self.av):
             z = noise(2000, latent_dim, device=self.rank)
-            samples = evaluate_in_chunks(self.batch_size, S, z).cpu().numpy()
+            samples = evaluate_in_chunks(batch_size, S, z).cpu().numpy()
             self.av = np.mean(samples, axis = 0)
             self.av = np.expand_dims(self.av, axis = 0)
-
+
+        av_torch = torch.from_numpy(self.av).cuda(self.rank)
+        tensor = trunc_psi * (tensor - av_torch) + av_torch
+        return tensor
+
+    @torch.no_grad()
+    def truncate_style_defs(self, w, trunc_psi = 0.75):
         w_space = []
-        for tensor, num_layers in style:
-            tmp = S(tensor)
-            av_torch = torch.from_numpy(self.av).cuda(self.rank)
-            tmp = trunc_psi * (tmp - av_torch) + av_torch
-            w_space.append((tmp, num_layers))
+        for tensor, num_layers in w:
+            tensor = self.truncate_style(tensor, trunc_psi = trunc_psi)
+            w_space.append((tensor, num_layers))
+        return w_space
 
-        w_styles = styles_def_to_tensor(w_space)
+    @torch.no_grad()
+    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
+        w = map(lambda t: (S(t[0]), t[1]), style)
+        w_truncated = self.truncate_style_defs(w, trunc_psi = trunc_psi)
+        w_styles = styles_def_to_tensor(w_truncated)
         generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
         return generated_images.clamp_(0., 1.)
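The refactor above isolates the truncation trick into its own `truncate_style` method. As a standalone sketch (not the library's code), the computation is: estimate a mean style vector by mapping many noise samples through the mapping network, then interpolate each style toward that mean by `trunc_psi`.

```python
import torch

def truncate(w, w_avg, trunc_psi = 0.75):
    # interpolate a style vector toward the mean style;
    # trunc_psi = 1.0 leaves w unchanged, smaller values trade diversity for fidelity
    return trunc_psi * (w - w_avg) + w_avg

# stand-in tensors for illustration; in the Trainer, w_avg is the mean of S(z) over 2000 sampled z
w_avg = torch.randn(1, 512)
w = torch.randn(4, 512)
w_truncated = truncate(w, w_avg, trunc_psi = 0.7)
```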

@@ -1159,8 +1211,8 @@ def init_folders(self):
         (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)
 
     def clear(self):
-        rmtree(f'./models/{self.name}', True)
-        rmtree(f'./results/{self.name}', True)
+        rmtree(str(self.models_dir / self.name), True)
+        rmtree(str(self.results_dir / self.name), True)
         rmtree(str(self.config_path), True)
         self.init_folders()

@@ -1202,3 +1254,27 @@ def load(self, num = -1):
             raise e
         if self.GAN.fp16 and 'amp' in load_data:
             amp.load_state_dict(load_data['amp'])
+
+class ModelLoader:
+    def __init__(self, *, base_dir, name = 'default', load_from = -1):
+        self.model = Trainer(name = name, base_dir = base_dir)
+        self.model.load(load_from)
+
+    def noise_to_styles(self, noise, trunc_psi = None):
+        w = self.model.GAN.S(noise)
+        if exists(trunc_psi):
+            w = self.model.truncate_style(w)
+        return w
+
+    def styles_to_images(self, w):
+        batch_size, *_ = w.shape
+        num_layers = self.model.GAN.G.num_layers
+        image_size = self.model.image_size
+        w_def = [(w, num_layers)]
+
+        w_tensors = styles_def_to_tensor(w_def)
+        noise = image_noise(batch_size, image_size, device = 0)
+
+        images = self.model.GAN.G(w_tensors, noise)
+        images.clamp_(0., 1.)
+        return images
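Since `ModelLoader` simply wraps a `Trainer` and forwards `load_from` to `Trainer.load`, a specific checkpoint can be selected instead of the most recent one. A minimal sketch follows; the checkpoint number and directory are assumptions.

```python
from stylegan2_pytorch import ModelLoader

# load_from = -1 (the default) loads the latest checkpoint;
# passing a checkpoint number selects that save instead
loader = ModelLoader(base_dir = '/path/to/directory', name = 'default', load_from = 10)
```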

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.8'
+__version__ = '1.4.0'
