@@ -23,6 +23,8 @@
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+from kornia.filters import filter2D
+
 import torchvision
 from torchvision import transforms
 from stylegan2_pytorch.diff_augment import DiffAugment
@@ -101,6 +103,16 @@ def forward(self, x):
         out = out.permute(0, 3, 1, 2)
         return out, loss
 
+class Blur(nn.Module):
+    def __init__(self):
+        super().__init__()
+        f = torch.Tensor([1, 2, 1])
+        self.register_buffer('f', f)
+    def forward(self, x):
+        f = self.f
+        f = f[None, None, :] * f[None, :, None]
+        return filter2D(x, f, normalized=True)
+
 # one layer of self-attention and feedforward, for images
 
 attn_and_ff = lambda chan: nn.Sequential(*[
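
The new Blur module builds a 3x3 binomial kernel as the outer product of [1, 2, 1] with itself and applies it depthwise via kornia's filter2D (renamed filter2d in later kornia releases); normalized=True divides the kernel by its sum so the blur preserves overall brightness. A minimal plain-PyTorch sketch of the same operation, assuming kornia's default reflect padding:

import torch
import torch.nn.functional as F

def binomial_blur(x):
    # outer product of [1, 2, 1] with itself -> 3x3 binomial kernel
    f = torch.tensor([1., 2., 1.], device = x.device)
    k = f[None, :] * f[:, None]
    # normalize to sum to 1, one copy of the kernel per input channel
    k = (k / k.sum()).expand(x.shape[1], 1, 3, 3)
    # reflect-pad, then filter each channel independently (depthwise)
    x = F.pad(x, (1, 1, 1, 1), mode = 'reflect')
    return F.conv2d(x, k, groups = x.shape[1])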
@@ -352,7 +364,10 @@ def __init__(self, latent_dim, input_channel, upsample, rgba = False):
         out_filters = 3 if not rgba else 4
         self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
 
-        self.upsample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False) if upsample else None
+        self.upsample = nn.Sequential(
+            nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False),
+            Blur()
+        ) if upsample else None
 
     def forward(self, x, prev_rgb, istyle):
         b, c, h, w = x.shape
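
With this change the RGB skip path blurs immediately after bilinear 2x upsampling, smoothing resampling artifacts before the next block's RGB contribution is added. A quick shape check of the new branch, assuming the Blur module defined above is in scope:

import torch
import torch.nn as nn

upsample = nn.Sequential(
    nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False),
    Blur()
)
rgb = torch.randn(1, 3, 64, 64)
assert upsample(rgb).shape == (1, 3, 128, 128)  # Blur keeps the spatial size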
@@ -450,7 +465,10 @@ def __init__(self, input_channels, filters, downsample=True):
             leaky_relu()
         )
 
-        self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None
+        self.downsample = nn.Sequential(
+            Blur(),
+            nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
+        ) if downsample else None
 
     def forward(self, x):
         res = self.conv_res(x)
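
On the discriminator side the order is reversed: blur first, then the stride-2 convolution, so high frequencies are attenuated before subsampling. This is the same anti-aliasing idea as BlurPool from Zhang, "Making Convolutional Networks Shift-Invariant Again" (2019). Again a shape check, assuming Blur from above:

import torch
import torch.nn as nn

filters = 64
downsample = nn.Sequential(
    Blur(),
    nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
)
x = torch.randn(1, filters, 32, 32)
assert downsample(x).shape == (1, filters, 16, 16)  # blurred, then halved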