|
23 | 23 | from torch.utils.data.distributed import DistributedSampler |
24 | 24 | from torch.nn.parallel import DistributedDataParallel as DDP |
25 | 25 |
|
26 | | -from kornia.filters import filter2D |
27 | | - |
28 | 26 | import torchvision |
29 | 27 | from torchvision import transforms |
30 | 28 | from stylegan2_pytorch.diff_augment import DiffAugment |
@@ -103,16 +101,6 @@ def forward(self, x): |
103 | 101 | out = out.permute(0, 3, 1, 2) |
104 | 102 | return out, loss |
105 | 103 |
|
106 | | -class Blur(nn.Module): |
107 | | - def __init__(self): |
108 | | - super().__init__() |
109 | | - f = torch.Tensor([1, 2, 1]) |
110 | | - self.register_buffer('f', f) |
111 | | - def forward(self, x): |
112 | | - f = self.f |
113 | | - f = f[None, None, :] * f [None, :, None] |
114 | | - return filter2D(x, f, normalized=True) |
115 | | - |
116 | 104 | # one layer of self-attention and feedforward, for images |
117 | 105 |
|
118 | 106 | attn_and_ff = lambda chan: nn.Sequential(*[ |
@@ -364,10 +352,7 @@ def __init__(self, latent_dim, input_channel, upsample, rgba = False): |
364 | 352 | out_filters = 3 if not rgba else 4 |
365 | 353 | self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) |
366 | 354 |
|
367 | | - self.upsample = nn.Sequential( |
368 | | - nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False), |
369 | | - Blur() |
370 | | - ) if upsample else None |
| 355 | + self.upsample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False) if upsample else None |
371 | 356 |
|
372 | 357 | def forward(self, x, prev_rgb, istyle): |
373 | 358 | b, c, h, w = x.shape |
@@ -465,10 +450,7 @@ def __init__(self, input_channels, filters, downsample=True): |
465 | 450 | leaky_relu() |
466 | 451 | ) |
467 | 452 |
|
468 | | - self.downsample = nn.Sequential( |
469 | | - Blur(), |
470 | | - nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) |
471 | | - ) if downsample else None |
| 453 | + self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None |
472 | 454 |
|
473 | 455 | def forward(self, x): |
474 | 456 | res = self.conv_res(x) |
|
0 commit comments