
Commit d65a742

intent to build (#210)
complete SepViT, from bytedance AI labs
1 parent 8c54e01 commit d65a742

File tree

4 files changed: +337 -3 lines changed


README.md

Lines changed: 40 additions & 2 deletions
@@ -19,6 +19,7 @@
 - [CrossFormer](#crossformer)
 - [RegionViT](#regionvit)
 - [ScalableViT](#scalablevit)
+- [SepViT](#sepvit)
 - [NesT](#nest)
 - [MobileViT](#mobilevit)
 - [Masked Autoencoder](#masked-autoencoder)
@@ -559,13 +560,42 @@ model = ScalableViT(
     reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
     window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
     dropout = 0.1, # attention and feedforward dropout
-).cuda()
+)
 
-img = torch.randn(1, 3, 256, 256).cuda()
+img = torch.randn(1, 3, 256, 256)
 
 preds = model(img) # (1, 1000)
 ```
 
+## SepViT
+
+<img src="./images/sep-vit.png" width="400px"></img>
+
+Another <a href="https://arxiv.org/abs/2203.15380">Bytedance AI paper</a>, this one proposes a depthwise-pointwise self-attention layer that seems largely inspired by MobileNet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.
+
+I have decided to include only the version of `SepViT` with this specific self-attention layer, as the grouped attention layers are neither remarkable nor novel, and the authors were unclear about how they treated the window tokens for the grouped self-attention layer. Besides, it seems that with the `DSSA` layer alone, they were able to beat Swin.
+
+ex. SepViT-Lite
+
+```python
+import torch
+from vit_pytorch.sep_vit import SepViT
+
+v = SepViT(
+    num_classes = 1000,
+    dim = 32,               # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
+    dim_head = 32,          # attention head dimension
+    heads = (1, 2, 4, 8),   # number of heads per stage
+    depth = (1, 2, 6, 2),   # number of transformer blocks per stage
+    window_size = 7,        # window size of the DSSA block at each stage
+    dropout = 0.1           # dropout
+)
+
+img = torch.randn(1, 3, 224, 224)
+
+preds = v(img) # (1, 1000)
+```
+
 ## NesT
 
 <img src="./images/nest.png" width="400px"></img>
@@ -1506,6 +1536,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@inproceedings{Li2022SepViTSV,
+    title   = {SepViT: Separable Vision Transformer},
+    author  = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
+    year    = {2022}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
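
To make the depthwise-pointwise flow above concrete, here is a minimal illustrative sketch that runs the `DSSA` module from `vit_pytorch/sep_vit.py` on a feature map directly. The dimensions below are arbitrary; the only requirement is that the feature map height and width are divisible by the window size.

```python
import torch
from vit_pytorch.sep_vit import DSSA

attn = DSSA(
    dim = 64,          # channel dimension of the incoming feature map
    heads = 2,
    dim_head = 32,
    window_size = 7    # height and width must be divisible by this
)

fmap = torch.randn(1, 64, 28, 28)

# depthwise (within-window) attention produces per-window feature maps and window tokens,
# pointwise attention over the window tokens then mixes those feature maps across windows
out = attn(fmap) # (1, 64, 28, 28)
```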

images/sep-vit.png

142 KB

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.31.1',
+  version = '0.32.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/sep_vit.py

Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = ChanLayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))

class OverlappingPatchEmbed(nn.Module):
    def __init__(self, dim_in, dim_out, stride = 2):
        super().__init__()
        kernel_size = stride * 2 - 1
        padding = kernel_size // 2
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention

class DSSA(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # window tokens

        self.window_tokens = nn.Parameter(torch.randn(dim))

        # prenorm and non-linearity for window tokens
        # then projection to queries and keys for window tokens

        self.window_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1, groups = heads),
            Rearrange('b (h c) n -> b h n c', h = heads),
        )

        # window attention

        self.window_attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """
        einstein notation

        b - batch
        c - channels
        w1 - window size (height)
        w2 - also window size (width)
        i - sequence dimension (source)
        j - sequence dimension (target dimension to be reduced)
        h - heads
        x - height of feature map divided by window size
        y - width of feature map divided by window size
        """

        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
        num_windows = (height // wsz) * (width // wsz)

        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention

        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
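        # shape is now (batch * num windows, dim, window size ** 2) - one token sequence per window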

        # add windowing tokens

        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
        x = torch.cat((w, x), dim = -1)
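        # each window sequence is now prefixed with its learned window token - (batch * num windows, dim, window size ** 2 + 1)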

        # project for queries, keys, value

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # scale

        q = q * self.scale

        # similarity

        dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # attention

        attn = self.attend(dots)

        # aggregate values

        out = torch.matmul(attn, v)

        # split out windowed tokens

        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
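        # window_tokens - (batch * num windows, heads, dim_head), one summary token per window
        # windowed_fmaps - (batch * num windows, heads, window size ** 2, dim_head)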

        # early return if there is only 1 window

        if num_windows == 1:
            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
            return self.to_out(fmap)

        # carry out the pointwise attention, the main novelty in the paper

        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)

        # windowed queries and keys (preceded by prenorm activation)

        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)

        # scale

        w_q = w_q * self.scale

        # similarities

        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)

        w_attn = self.window_attend(w_dots)

        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)

        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
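        # each window now holds a weighted combination of every window's depthwise-attended feature map - the reuse of values described in the readme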

        # fold back the windows and then combine heads for aggregation

        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
        return self.to_out(fmap)

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        dropout = 0.,
        window_size = 7,
        norm_output = True
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
            ]))

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class SepViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        window_size = 7,
        dim_head = 32,
        ff_mult = 4,
        channels = 3,
        dropout = 0.
    ):
        super().__init__()
        self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        num_stages = len(depth)

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (channels, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        strides = (4, *((2,) * (num_stages - 1)))
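        # strides of (4, 2, 2, 2) give a total downsampling of 32, so a 224 x 224 input ends at a 7 x 7 feature map - a single window at the default window size of 7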

        hyperparams_per_stage = [heads, window_size]
        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        self.layers = nn.ModuleList([])

        for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            self.layers.append(nn.ModuleList([
                OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
                PEG(layer_dim),
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, window_size = layer_window_size, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
            ]))

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):

        for ope, peg, transformer in self.layers:
            x = ope(x)
            x = peg(x)
            x = transformer(x)

        return self.mlp_head(x)
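
As a sanity check on how the stages compose, here is a small illustrative walkthrough of the SepViT-Lite configuration from the README, stepping through the model's internal `layers` list directly. It is a sketch only; the stage resolutions follow from the (4, 2, 2, 2) strides and the channel dims from the stage-wise doubling.

```python
import torch
from vit_pytorch.sep_vit import SepViT

v = SepViT(
    num_classes = 1000,
    dim = 32,
    dim_head = 32,
    heads = (1, 2, 4, 8),
    depth = (1, 2, 6, 2),
    window_size = 7
)

x = torch.randn(1, 3, 224, 224)

# each stage: overlapping patch embed -> positional encoding generator -> DSSA transformer
for ope, peg, transformer in v.layers:
    x = transformer(peg(ope(x)))
    print(x.shape) # (1, 32, 56, 56) -> (1, 64, 28, 28) -> (1, 128, 14, 14) -> (1, 256, 7, 7)

preds = v.mlp_head(x) # (1, 1000)
```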
