Skip to content

Commit c7bb5fc

Browse files
authored
maxvit intent to build (#211)
complete hybrid mbconv + block / grid efficient self attention MaxViT
1 parent 946b19b commit c7bb5fc

File tree

7 files changed

+317
-7
lines changed

7 files changed

+317
-7
lines changed

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- [RegionViT](#regionvit)
2121
- [ScalableViT](#scalablevit)
2222
- [SepViT](#sepvit)
23+
- [MaxViT](#maxvit)
2324
- [NesT](#nest)
2425
- [MobileViT](#mobilevit)
2526
- [Masked Autoencoder](#masked-autoencoder)
@@ -596,6 +597,37 @@ img = torch.randn(1, 3, 224, 224)
596597
preds = v(img) # (1, 1000)
597598
```
598599

600+
## MaxViT
601+
602+
<img src="./images/max-vit.png" width="400px"></img>
603+
604+
This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
605+
606+
They also claim this specific vision transformer is good for generative models (GANs).
607+
608+
ex. MaxViT-S
609+
610+
```python
611+
import torch
612+
from vit_pytorch.max_vit import MaxViT
613+
614+
v = MaxViT(
615+
num_classes = 1000,
616+
dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified
617+
dim = 96, # dimension of first layer, doubles every layer
618+
dim_head = 32, # dimension of attention heads, kept at 32 in paper
619+
depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
620+
window_size = 7, # window size for block and grids
621+
mbconv_expansion_rate = 4, # expansion rate of MBConv
622+
mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv
623+
dropout = 0.1 # dropout
624+
)
625+
626+
img = torch.randn(2, 3, 224, 224)
627+
628+
preds = v(img) # (2, 1000)
629+
```
630+
599631
## NesT
600632

601633
<img src="./images/nest.png" width="400px"></img>
@@ -1544,6 +1576,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
15441576
}
15451577
```
15461578

1579+
```bibtex
1580+
@inproceedings{Tu2022MaxViTMV,
1581+
title = {MaxViT: Multi-Axis Vision Transformer},
1582+
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
1583+
year = {2022}
1584+
}
1585+
```
1586+
15471587
```bibtex
15481588
@misc{vaswani2017attention,
15491589
title = {Attention Is All You Need},

images/max-vit.png

133 KB
Loading

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.32.2',
6+
version = '0.33.0',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',
@@ -16,7 +16,7 @@
1616
],
1717
install_requires=[
1818
'einops>=0.4.1',
19-
'torch>=1.6',
19+
'torch>=1.10',
2020
'torchvision'
2121
],
2222
setup_requires=[

vit_pytorch/crossformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
# calculate and store indices for retrieving bias
109109

110110
pos = torch.arange(window_size)
111-
grid = torch.stack(torch.meshgrid(pos, pos))
111+
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
112112
grid = rearrange(grid, 'c i j -> (i j) c')
113113
rel_pos = grid[:, None] - grid[None, :]
114114
rel_pos += window_size - 1
@@ -144,7 +144,7 @@ def forward(self, x):
144144
# add dynamic positional bias
145145

146146
pos = torch.arange(-wsz, wsz + 1, device = device)
147-
rel_pos = torch.stack(torch.meshgrid(pos, pos))
147+
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
148148
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
149149
biases = self.dpb(rel_pos.float())
150150
rel_pos_bias = biases[self.rel_pos_indices]

vit_pytorch/levit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, drop
7171
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
7272
k_range = torch.arange(fmap_size)
7373

74-
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
75-
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
74+
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
75+
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
7676

7777
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
7878
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

vit_pytorch/max_vit.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
from functools import partial
2+
3+
import torch
4+
from torch import nn, einsum
5+
6+
from einops import rearrange, repeat
7+
from einops.layers.torch import Rearrange, Reduce
8+
9+
# helpers
10+
11+
def exists(val):
12+
return val is not None
13+
14+
def default(val, d):
15+
return val if exists(val) else d
16+
17+
def cast_tuple(val, length = 1):
18+
return val if isinstance(val, tuple) else ((val,) * length)
19+
20+
# helper classes
21+
22+
class PreNormResidual(nn.Module):
23+
def __init__(self, dim, fn):
24+
super().__init__()
25+
self.norm = nn.LayerNorm(dim)
26+
self.fn = fn
27+
28+
def forward(self, x):
29+
return self.fn(self.norm(x)) + x
30+
31+
# MBConv
32+
33+
class SqueezeExcitation(nn.Module):
34+
def __init__(self, dim, shrinkage_rate = 0.25):
35+
super().__init__()
36+
hidden_dim = int(dim * shrinkage_rate)
37+
38+
self.gate = nn.Sequential(
39+
Reduce('b c h w -> b c', 'mean'),
40+
nn.Linear(dim, hidden_dim, bias = False),
41+
nn.SiLU(),
42+
nn.Linear(hidden_dim, dim, bias = False),
43+
nn.Sigmoid(),
44+
Rearrange('b c -> b c 1 1')
45+
)
46+
47+
def forward(self, x):
48+
return x * self.gate(x)
49+
50+
51+
class MBConvResidual(nn.Module):
52+
def __init__(self, fn, dropout = 0.):
53+
super().__init__()
54+
self.fn = fn
55+
self.dropsample = Dropsample(dropout)
56+
57+
def forward(self, x):
58+
out = self.fn(x)
59+
out = self.dropsample(out)
60+
return out
61+
62+
class Dropsample(nn.Module):
63+
def __init__(self, prob = 0):
64+
super().__init__()
65+
self.prob = prob
66+
67+
def forward(self, x):
68+
device = x.device
69+
70+
if self.prob == 0. or (not self.training):
71+
return x
72+
73+
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
74+
return x * keep_mask / (1 - self.prob)
75+
76+
def MBConv(
77+
dim_in,
78+
dim_out,
79+
*,
80+
downsample,
81+
expansion_rate = 4,
82+
shrinkage_rate = 0.25,
83+
dropout = 0.
84+
):
85+
hidden_dim = int(expansion_rate * dim_out)
86+
stride = 2 if downsample else 1
87+
88+
net = nn.Sequential(
89+
nn.Conv2d(dim_in, dim_out, 1),
90+
nn.BatchNorm2d(dim_out),
91+
nn.SiLU(),
92+
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
93+
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
94+
nn.Conv2d(dim_out, dim_out, 1),
95+
nn.BatchNorm2d(dim_out)
96+
)
97+
98+
if dim_in == dim_out and not downsample:
99+
net = MBConvResidual(net, dropout = dropout)
100+
101+
return net
102+
103+
# attention related classes
104+
105+
class Attention(nn.Module):
106+
def __init__(
107+
self,
108+
dim,
109+
dim_head = 32,
110+
dropout = 0.,
111+
window_size = 7
112+
):
113+
super().__init__()
114+
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
115+
116+
self.heads = dim // dim_head
117+
self.scale = dim_head ** -0.5
118+
119+
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
120+
121+
self.attend = nn.Sequential(
122+
nn.Softmax(dim = -1),
123+
nn.Dropout(dropout)
124+
)
125+
126+
self.to_out = nn.Sequential(
127+
nn.Linear(dim, dim, bias = False),
128+
nn.Dropout(dropout)
129+
)
130+
131+
# relative positional bias
132+
133+
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
134+
135+
pos = torch.arange(window_size)
136+
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
137+
grid = rearrange(grid, 'c i j -> (i j) c')
138+
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
139+
rel_pos += window_size - 1
140+
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
141+
142+
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
143+
144+
def forward(self, x):
145+
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
146+
147+
# flatten
148+
149+
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
150+
151+
# project for queries, keys, values
152+
153+
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
154+
155+
# split heads
156+
157+
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
158+
159+
# scale
160+
161+
q = q * self.scale
162+
163+
# sim
164+
165+
sim = einsum('b h i d, b h j d -> b h i j', q, k)
166+
167+
# add positional bias
168+
169+
bias = self.rel_pos_bias(self.rel_pos_indices)
170+
sim = sim + rearrange(bias, 'i j h -> h i j')
171+
172+
# attention
173+
174+
attn = self.attend(sim)
175+
176+
# aggregate
177+
178+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
179+
180+
# merge heads
181+
182+
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
183+
184+
# combine heads out
185+
186+
out = self.to_out(out)
187+
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
188+
189+
class MaxViT(nn.Module):
190+
def __init__(
191+
self,
192+
*,
193+
num_classes,
194+
dim,
195+
depth,
196+
dim_head = 32,
197+
dim_conv_stem = None,
198+
window_size = 7,
199+
mbconv_expansion_rate = 4,
200+
mbconv_shrinkage_rate = 0.25,
201+
dropout = 0.1,
202+
channels = 3
203+
):
204+
super().__init__()
205+
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
206+
207+
# convolutional stem
208+
209+
dim_conv_stem = default(dim_conv_stem, dim)
210+
211+
self.conv_stem = nn.Sequential(
212+
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
213+
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
214+
)
215+
216+
# variables
217+
218+
num_stages = len(depth)
219+
220+
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
221+
dims = (dim_conv_stem, *dims)
222+
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
223+
224+
self.layers = nn.ModuleList([])
225+
226+
# shorthand for window size for efficient block - grid like attention
227+
228+
w = window_size
229+
230+
# iterate through stages
231+
232+
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
233+
for stage_ind in range(layer_depth):
234+
is_first = stage_ind == 0
235+
stage_dim_in = layer_dim_in if is_first else layer_dim
236+
237+
block = nn.Sequential(
238+
MBConv(
239+
stage_dim_in,
240+
layer_dim,
241+
downsample = is_first,
242+
expansion_rate = mbconv_expansion_rate,
243+
shrinkage_rate = mbconv_shrinkage_rate
244+
),
245+
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
246+
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
247+
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
248+
249+
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
250+
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
251+
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
252+
)
253+
254+
self.layers.append(block)
255+
256+
# mlp head out
257+
258+
self.mlp_head = nn.Sequential(
259+
Reduce('b d h w -> b d', 'mean'),
260+
nn.LayerNorm(dims[-1]),
261+
nn.Linear(dims[-1], num_classes)
262+
)
263+
264+
def forward(self, x):
265+
x = self.conv_stem(x)
266+
267+
for stage in self.layers:
268+
x = stage(x)
269+
270+
return self.mlp_head(x)

vit_pytorch/regionvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def forward(self, local_tokens, region_tokens):
138138
h_range = torch.arange(window_size_h, device = device)
139139
w_range = torch.arange(window_size_w, device = device)
140140

141-
grid_x, grid_y = torch.meshgrid(h_range, w_range)
141+
grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
142142
grid = torch.stack((grid_x, grid_y))
143143
grid = rearrange(grid, 'c h w -> c (h w)')
144144
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)

0 commit comments

Comments
 (0)