|
| 1 | +from functools import partial |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn, einsum |
| 5 | + |
| 6 | +from einops import rearrange, repeat |
| 7 | +from einops.layers.torch import Rearrange, Reduce |
| 8 | + |
| 9 | +# helpers |
| 10 | + |
| 11 | +def cast_tuple(val, length = 1): |
| 12 | + return val if isinstance(val, tuple) else ((val,) * length) |
| 13 | + |
| 14 | +# helper classes |
| 15 | + |
| 16 | +class ChanLayerNorm(nn.Module): |
| 17 | + def __init__(self, dim, eps = 1e-5): |
| 18 | + super().__init__() |
| 19 | + self.eps = eps |
| 20 | + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) |
| 21 | + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) |
| 22 | + |
| 23 | + def forward(self, x): |
| 24 | + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) |
| 25 | + mean = torch.mean(x, dim = 1, keepdim = True) |
| 26 | + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b |
| 27 | + |
| 28 | +class PreNorm(nn.Module): |
| 29 | + def __init__(self, dim, fn): |
| 30 | + super().__init__() |
| 31 | + self.norm = ChanLayerNorm(dim) |
| 32 | + self.fn = fn |
| 33 | + |
| 34 | + def forward(self, x): |
| 35 | + return self.fn(self.norm(x)) |
| 36 | + |
| 37 | +class OverlappingPatchEmbed(nn.Module): |
| 38 | + def __init__(self, dim_in, dim_out, stride = 2): |
| 39 | + super().__init__() |
| 40 | + kernel_size = stride * 2 - 1 |
| 41 | + padding = kernel_size // 2 |
| 42 | + self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding) |
| 43 | + |
| 44 | + def forward(self, x): |
| 45 | + return self.conv(x) |
| 46 | + |
| 47 | +class PEG(nn.Module): |
| 48 | + def __init__(self, dim, kernel_size = 3): |
| 49 | + super().__init__() |
| 50 | + self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1) |
| 51 | + |
| 52 | + def forward(self, x): |
| 53 | + return self.proj(x) + x |
| 54 | + |
| 55 | +# feedforward |
| 56 | + |
| 57 | +class FeedForward(nn.Module): |
| 58 | + def __init__(self, dim, mult = 4, dropout = 0.): |
| 59 | + super().__init__() |
| 60 | + inner_dim = int(dim * mult) |
| 61 | + self.net = nn.Sequential( |
| 62 | + nn.Conv2d(dim, inner_dim, 1), |
| 63 | + nn.GELU(), |
| 64 | + nn.Dropout(dropout), |
| 65 | + nn.Conv2d(inner_dim, dim, 1), |
| 66 | + nn.Dropout(dropout) |
| 67 | + ) |
| 68 | + def forward(self, x): |
| 69 | + return self.net(x) |
| 70 | + |
| 71 | +# attention |
| 72 | + |
| 73 | +class DSSA(nn.Module): |
| 74 | + def __init__( |
| 75 | + self, |
| 76 | + dim, |
| 77 | + heads = 8, |
| 78 | + dim_head = 32, |
| 79 | + dropout = 0., |
| 80 | + window_size = 7 |
| 81 | + ): |
| 82 | + super().__init__() |
| 83 | + self.heads = heads |
| 84 | + self.scale = dim_head ** -0.5 |
| 85 | + self.window_size = window_size |
| 86 | + inner_dim = dim_head * heads |
| 87 | + |
| 88 | + self.attend = nn.Sequential( |
| 89 | + nn.Softmax(dim = -1), |
| 90 | + nn.Dropout(dropout) |
| 91 | + ) |
| 92 | + |
| 93 | + self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False) |
| 94 | + |
| 95 | + # window tokens |
| 96 | + |
| 97 | + self.window_tokens = nn.Parameter(torch.randn(dim)) |
| 98 | + |
| 99 | + # prenorm and non-linearity for window tokens |
| 100 | + # then projection to queries and keys for window tokens |
| 101 | + |
| 102 | + self.window_tokens_to_qk = nn.Sequential( |
| 103 | + nn.LayerNorm(dim_head), |
| 104 | + nn.GELU(), |
| 105 | + Rearrange('b h n c -> b (h c) n'), |
| 106 | + nn.Conv1d(inner_dim, inner_dim * 2, 1, groups = heads), |
| 107 | + Rearrange('b (h c) n -> b h n c', h = heads), |
| 108 | + ) |
| 109 | + |
| 110 | + # window attention |
| 111 | + |
| 112 | + self.window_attend = nn.Sequential( |
| 113 | + nn.Softmax(dim = -1), |
| 114 | + nn.Dropout(dropout) |
| 115 | + ) |
| 116 | + |
| 117 | + self.to_out = nn.Sequential( |
| 118 | + nn.Conv2d(inner_dim, dim, 1), |
| 119 | + nn.Dropout(dropout) |
| 120 | + ) |
| 121 | + |
| 122 | + def forward(self, x): |
| 123 | + """ |
| 124 | + einstein notation |
| 125 | +
|
| 126 | + b - batch |
| 127 | + c - channels |
| 128 | + w1 - window size (height) |
| 129 | + w2 - also window size (width) |
| 130 | + i - sequence dimension (source) |
| 131 | + j - sequence dimension (target dimension to be reduced) |
| 132 | + h - heads |
| 133 | + x - height of feature map divided by window size |
| 134 | + y - width of feature map divided by window size |
| 135 | + """ |
| 136 | + |
| 137 | + batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size |
| 138 | + assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}' |
| 139 | + num_windows = (height // wsz) * (width // wsz) |
| 140 | + |
| 141 | + # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention |
| 142 | + |
| 143 | + x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz) |
| 144 | + |
| 145 | + # add windowing tokens |
| 146 | + |
| 147 | + w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0]) |
| 148 | + x = torch.cat((w, x), dim = -1) |
| 149 | + |
| 150 | + # project for queries, keys, value |
| 151 | + |
| 152 | + q, k, v = self.to_qkv(x).chunk(3, dim = 1) |
| 153 | + |
| 154 | + # split out heads |
| 155 | + |
| 156 | + q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v)) |
| 157 | + |
| 158 | + # scale |
| 159 | + |
| 160 | + q = q * self.scale |
| 161 | + |
| 162 | + # similarity |
| 163 | + |
| 164 | + dots = einsum('b h i d, b h j d -> b h i j', q, k) |
| 165 | + |
| 166 | + # attention |
| 167 | + |
| 168 | + attn = self.attend(dots) |
| 169 | + |
| 170 | + # aggregate values |
| 171 | + |
| 172 | + out = torch.matmul(attn, v) |
| 173 | + |
| 174 | + # split out windowed tokens |
| 175 | + |
| 176 | + window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:] |
| 177 | + |
| 178 | + # early return if there is only 1 window |
| 179 | + |
| 180 | + if num_windows == 1: |
| 181 | + fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz) |
| 182 | + return self.to_out(fmap) |
| 183 | + |
| 184 | + # carry out the pointwise attention, the main novelty in the paper |
| 185 | + |
| 186 | + window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz) |
| 187 | + windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz) |
| 188 | + |
| 189 | + # windowed queries and keys (preceded by prenorm activation) |
| 190 | + |
| 191 | + w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1) |
| 192 | + |
| 193 | + # scale |
| 194 | + |
| 195 | + w_q = w_q * self.scale |
| 196 | + |
| 197 | + # similarities |
| 198 | + |
| 199 | + w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k) |
| 200 | + |
| 201 | + w_attn = self.window_attend(w_dots) |
| 202 | + |
| 203 | + # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before) |
| 204 | + |
| 205 | + aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps) |
| 206 | + |
| 207 | + # fold back the windows and then combine heads for aggregation |
| 208 | + |
| 209 | + fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz) |
| 210 | + return self.to_out(fmap) |
| 211 | + |
| 212 | +class Transformer(nn.Module): |
| 213 | + def __init__( |
| 214 | + self, |
| 215 | + dim, |
| 216 | + depth, |
| 217 | + dim_head = 32, |
| 218 | + heads = 8, |
| 219 | + ff_mult = 4, |
| 220 | + dropout = 0., |
| 221 | + norm_output = True |
| 222 | + ): |
| 223 | + super().__init__() |
| 224 | + self.layers = nn.ModuleList([]) |
| 225 | + |
| 226 | + for ind in range(depth): |
| 227 | + self.layers.append(nn.ModuleList([ |
| 228 | + PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)), |
| 229 | + PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)), |
| 230 | + ])) |
| 231 | + |
| 232 | + self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity() |
| 233 | + |
| 234 | + def forward(self, x): |
| 235 | + for attn, ff in self.layers: |
| 236 | + x = attn(x) + x |
| 237 | + x = ff(x) + x |
| 238 | + |
| 239 | + return self.norm(x) |
| 240 | + |
| 241 | +class SepViT(nn.Module): |
| 242 | + def __init__( |
| 243 | + self, |
| 244 | + *, |
| 245 | + num_classes, |
| 246 | + dim, |
| 247 | + depth, |
| 248 | + heads, |
| 249 | + window_size = 7, |
| 250 | + dim_head = 32, |
| 251 | + ff_mult = 4, |
| 252 | + channels = 3, |
| 253 | + dropout = 0. |
| 254 | + ): |
| 255 | + super().__init__() |
| 256 | + self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3) |
| 257 | + |
| 258 | + assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage' |
| 259 | + |
| 260 | + num_stages = len(depth) |
| 261 | + |
| 262 | + dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages))) |
| 263 | + dims = (channels, *dims) |
| 264 | + dim_pairs = tuple(zip(dims[:-1], dims[1:])) |
| 265 | + |
| 266 | + strides = (4, *((2,) * (num_stages - 1))) |
| 267 | + |
| 268 | + hyperparams_per_stage = [heads, window_size] |
| 269 | + hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage)) |
| 270 | + assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage))) |
| 271 | + |
| 272 | + self.layers = nn.ModuleList([]) |
| 273 | + |
| 274 | + for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)): |
| 275 | + is_last = ind == (num_stages - 1) |
| 276 | + |
| 277 | + self.layers.append(nn.ModuleList([ |
| 278 | + OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride), |
| 279 | + PEG(layer_dim), |
| 280 | + Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last), |
| 281 | + ])) |
| 282 | + |
| 283 | + self.mlp_head = nn.Sequential( |
| 284 | + Reduce('b d h w -> b d', 'mean'), |
| 285 | + nn.LayerNorm(dims[-1]), |
| 286 | + nn.Linear(dims[-1], num_classes) |
| 287 | + ) |
| 288 | + |
| 289 | + def forward(self, x): |
| 290 | + |
| 291 | + for ope, peg, transformer in self.layers: |
| 292 | + x = ope(x) |
| 293 | + x = peg(x) |
| 294 | + x = transformer(x) |
| 295 | + |
| 296 | + return self.mlp_head(x) |
0 commit comments