Skip to content

Commit e33f30e

Browse files
committed
add blurs
1 parent a34a3f2 commit e33f30e

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
'stylegan2_pytorch = stylegan2_pytorch.cli:main',
99
],
1010
},
11-
version = '1.0.2',
11+
version = '1.0.3',
1212
license='GPLv3+',
1313
description = 'StyleGan2 in Pytorch',
1414
author = 'Phil Wang',
@@ -17,15 +17,16 @@
1717
download_url = 'https://github.com/lucidrains/stylegan2-pytorch/archive/v_036.tar.gz',
1818
keywords = ['generative adversarial networks', 'artificial intelligence'],
1919
install_requires=[
20+
'contrastive_learner>=0.1.0',
2021
'fire',
22+
'kornia',
23+
'linear_attention_transformer',
2124
'numpy',
2225
'retry',
2326
'tqdm',
2427
'torch',
2528
'torchvision',
2629
'pillow',
27-
'contrastive_learner>=0.1.0',
28-
'linear_attention_transformer',
2930
'vector-quantize-pytorch'
3031
],
3132
classifiers=[

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from torch.utils.data.distributed import DistributedSampler
2424
from torch.nn.parallel import DistributedDataParallel as DDP
2525

26+
from kornia.filters import filter2D
27+
2628
import torchvision
2729
from torchvision import transforms
2830
from stylegan2_pytorch.diff_augment import DiffAugment
@@ -101,6 +103,16 @@ def forward(self, x):
101103
out = out.permute(0, 3, 1, 2)
102104
return out, loss
103105

106+
class Blur(nn.Module):
107+
def __init__(self):
108+
super().__init__()
109+
f = torch.Tensor([1, 2, 1])
110+
self.register_buffer('f', f)
111+
def forward(self, x):
112+
f = self.f
113+
f = f[None, None, :] * f [None, :, None]
114+
return filter2D(x, f, normalized=True)
115+
104116
# one layer of self-attention and feedforward, for images
105117

106118
attn_and_ff = lambda chan: nn.Sequential(*[
@@ -352,7 +364,10 @@ def __init__(self, latent_dim, input_channel, upsample, rgba = False):
352364
out_filters = 3 if not rgba else 4
353365
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
354366

355-
self.upsample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False) if upsample else None
367+
self.upsample = nn.Sequential(
368+
nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False),
369+
Blur()
370+
) if upsample else None
356371

357372
def forward(self, x, prev_rgb, istyle):
358373
b, c, h, w = x.shape
@@ -450,7 +465,10 @@ def __init__(self, input_channels, filters, downsample=True):
450465
leaky_relu()
451466
)
452467

453-
self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None
468+
self.downsample = nn.Sequential(
469+
Blur(),
470+
nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
471+
) if downsample else None
454472

455473
def forward(self, x):
456474
res = self.conv_res(x)

0 commit comments

Comments
 (0)