
Commit f7d59ce

some register tokens cannot hurt for VAT

1 parent a583cb5 commit f7d59ce

File tree

2 files changed: +32 −8 lines changed

pyproject.toml
vit_pytorch/vat.py

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.4"
+version = "1.14.5"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/vat.py

Lines changed: 31 additions & 7 deletions
@@ -178,7 +178,8 @@ def __init__(
         channels = 3,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        num_register_tokens = 0
     ):
         super().__init__()
         self.dim = dim
@@ -200,8 +201,8 @@ def __init__(
             nn.LayerNorm(dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
+        self.cls_token = nn.Parameter(torch.randn(dim))
         self.dropout = nn.Dropout(emb_dropout)
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
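
With this hunk the positional embedding is stored per patch only, with shape (num_patches, dim), and broadcast over the batch; since the cls and register tokens are packed in afterwards (see the forward-pass hunk below), they receive no positional embedding. A minimal shape sketch, with toy dimensions chosen only for illustration:

import torch

batch, num_patches, dim = 2, 196, 64

pos_embedding = torch.randn(num_patches, dim)    # one learned row per patch, no cls slot
patch_tokens = torch.randn(batch, num_patches, dim)

# (num_patches, dim) broadcasts against (batch, num_patches, dim)
patch_tokens = patch_tokens + pos_embedding[:num_patches]
print(patch_tokens.shape)    # torch.Size([2, 196, 64])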
@@ -211,13 +212,19 @@ def __init__(
 
         self.mlp_head = nn.Linear(dim, num_classes)
 
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
     def forward(self, img, return_hiddens = False):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
-        x = cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:n]
+
+        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
+
+        x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
+
         x = self.dropout(x)
 
         x, hiddens = self.transformer(x, return_hiddens = True)
@@ -227,7 +234,9 @@ def forward(
         if return_hiddens:
             return x, stack(hiddens)
 
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
 
         x = self.to_latent(x)
         return self.mlp_head(x)
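
The rewritten forward pass handles the extra tokens with einops pack/unpack instead of torch.cat and index slicing: registers and the cls token are broadcast over the batch, packed in front of the patch tokens, and recovered by position after the transformer. A self-contained sketch of that flow under assumed toy shapes (the transformer itself is elided):

import torch
from einops import repeat, pack, unpack

batch, num_patches, num_registers, dim = 2, 196, 4, 64

patch_tokens = torch.randn(batch, num_patches, dim)
cls_token = torch.randn(dim)
register_tokens = torch.randn(num_registers, dim) * 1e-2

# broadcast the learned tokens over the batch
cls_tokens = repeat(cls_token, 'd -> b d', b = batch)
registers = repeat(register_tokens, 'n d -> b n d', b = batch)

# pack concatenates along the token axis and remembers each piece's size
x, packed_shape = pack((registers, cls_tokens, patch_tokens), 'b * d')
print(x.shape)          # torch.Size([2, 201, 64]) = 4 registers + 1 cls + 196 patches

# ... transformer layers would run on x here ...

# unpack returns the pieces in the order they were packed
registers, cls_tokens, patch_tokens = unpack(x, packed_shape, 'b * d')
pooled = cls_tokens                     # 'cls' pooling, shape (2, 64)
pooled = patch_tokens.mean(dim = 1)     # or 'mean' pooling over the patch tokens
print(pooled.shape)     # torch.Size([2, 64])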
@@ -251,6 +260,7 @@ def __init__(
         num_views = None,
         num_tasks = None,
         dim_extra_token = None,
+        num_register_tokens = 4,
         action_chunk_len = 7,
         time_seq_len = 1,
         dropout = 0.,
@@ -295,6 +305,10 @@ def __init__(
         if self.has_tasks:
             self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
 
+        # register tokens from Darcet et al.
+
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -407,6 +421,12 @@ def forward(
 
         action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
 
+        # register tokens
+
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+
+        action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+
         # cross attention
 
         hiddens = [action_tokens]
@@ -425,6 +445,10 @@ def forward(
 
         hiddens.append(action_tokens)
 
+        # unpack registers
+
+        _, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
+
         # maybe unpack extra
 
         if has_extra:
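
On the VAT side the same pattern wraps the cross-attention stack: registers are packed in front of the action tokens before the layers run and discarded afterwards, so they only act as extra slots the attention can write to. A rough sketch of that pack/discard pattern under assumed shapes, with nn.MultiheadAttention standing in for the library's own attention layers:

import torch
from torch import nn
from einops import repeat, pack, unpack

batch, num_actions, num_registers, dim = 2, 7, 4, 64

action_tokens = torch.randn(batch, num_actions, dim)
register_params = nn.Parameter(torch.randn(num_registers, dim) * 1e-2)

# prepend registers to the action tokens
registers = repeat(register_params, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((registers, action_tokens), 'b * d')

# stand-in for the cross attention layers
attn = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)
action_tokens, _ = attn(action_tokens, action_tokens, action_tokens)

# registers are discarded after attention; only the action tokens continue
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
print(action_tokens.shape)    # torch.Size([2, 7, 64])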
