Skip to content

Commit a36546d

Browse files
committed
add simple vit with register tokens example, cite
1 parent d830b05 commit a36546d

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,4 +2020,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
20202020
}
20212021
```
20222022

2023+
```bibtex
2024+
@inproceedings{Darcet2023VisionTN,
2025+
title = {Vision Transformers Need Registers},
2026+
author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
2027+
year = {2023},
2028+
url = {https://api.semanticscholar.org/CorpusID:263134283}
2029+
}
2030+
```
2031+
20232032
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '1.4.5 ',
6+
version = '1.5.0 ',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
long_description_content_type = 'text/markdown',
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
from torch import nn
3+
4+
from einops import rearrange, repeat, pack, unpack
5+
from einops.layers.torch import Rearrange
6+
7+
# helpers
8+
9+
def pair(t):
10+
return t if isinstance(t, tuple) else (t, t)
11+
12+
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
13+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
14+
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
15+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
16+
omega = 1.0 / (temperature ** omega)
17+
18+
y = y.flatten()[:, None] * omega[None, :]
19+
x = x.flatten()[:, None] * omega[None, :]
20+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
21+
return pe.type(dtype)
22+
23+
# classes
24+
25+
class FeedForward(nn.Module):
26+
def __init__(self, dim, hidden_dim):
27+
super().__init__()
28+
self.net = nn.Sequential(
29+
nn.LayerNorm(dim),
30+
nn.Linear(dim, hidden_dim),
31+
nn.GELU(),
32+
nn.Linear(hidden_dim, dim),
33+
)
34+
def forward(self, x):
35+
return self.net(x)
36+
37+
class Attention(nn.Module):
38+
def __init__(self, dim, heads = 8, dim_head = 64):
39+
super().__init__()
40+
inner_dim = dim_head * heads
41+
self.heads = heads
42+
self.scale = dim_head ** -0.5
43+
self.norm = nn.LayerNorm(dim)
44+
45+
self.attend = nn.Softmax(dim = -1)
46+
47+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
49+
50+
def forward(self, x):
51+
x = self.norm(x)
52+
53+
qkv = self.to_qkv(x).chunk(3, dim = -1)
54+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
55+
56+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
57+
58+
attn = self.attend(dots)
59+
60+
out = torch.matmul(attn, v)
61+
out = rearrange(out, 'b h n d -> b n (h d)')
62+
return self.to_out(out)
63+
64+
class Transformer(nn.Module):
65+
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
66+
super().__init__()
67+
self.norm = nn.LayerNorm(dim)
68+
self.layers = nn.ModuleList([])
69+
for _ in range(depth):
70+
self.layers.append(nn.ModuleList([
71+
Attention(dim, heads = heads, dim_head = dim_head),
72+
FeedForward(dim, mlp_dim)
73+
]))
74+
def forward(self, x):
75+
for attn, ff in self.layers:
76+
x = attn(x) + x
77+
x = ff(x) + x
78+
return self.norm(x)
79+
80+
class SimpleViT(nn.Module):
81+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
82+
super().__init__()
83+
image_height, image_width = pair(image_size)
84+
patch_height, patch_width = pair(patch_size)
85+
86+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
87+
88+
patch_dim = channels * patch_height * patch_width
89+
90+
self.to_patch_embedding = nn.Sequential(
91+
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
92+
nn.LayerNorm(patch_dim),
93+
nn.Linear(patch_dim, dim),
94+
nn.LayerNorm(dim),
95+
)
96+
97+
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
98+
99+
self.pos_embedding = posemb_sincos_2d(
100+
h = image_height // patch_height,
101+
w = image_width // patch_width,
102+
dim = dim,
103+
)
104+
105+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
106+
107+
self.pool = "mean"
108+
self.to_latent = nn.Identity()
109+
110+
self.linear_head = nn.Linear(dim, num_classes)
111+
112+
def forward(self, img):
113+
batch, device = img.shape[0], img.device
114+
115+
x = self.to_patch_embedding(img)
116+
x += self.pos_embedding.to(device, dtype=x.dtype)
117+
118+
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
119+
120+
x, ps = pack([x, r], 'b * d')
121+
122+
x = self.transformer(x)
123+
124+
x, _ = unpack(x, ps, 'b * d')
125+
126+
x = x.mean(dim = 1)
127+
128+
x = self.to_latent(x)
129+
return self.linear_head(x)

0 commit comments

Comments
 (0)