Commit f6bc14c (1 parent: 845c844)

able to return embed from vit-nd-rotary

File tree: 2 files changed, +25 -11 lines


setup.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.12.1',
+  version = '1.12.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,

vit_pytorch/vit_nd_rotary.py

Lines changed: 24 additions & 10 deletions

@@ -5,7 +5,7 @@
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F

-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 # helpers
@@ -245,7 +245,11 @@ def __init__(
         self.to_latent = nn.Identity()
         self.mlp_head = nn.Linear(dim, num_classes)

-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_embed = False
+    ):
         x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)

         batch, *spatial_dims, _, device = *x.shape, x.device
@@ -259,16 +263,24 @@ def forward(self, x):
         # flatten spatial dimensions for attention with nd rotary

         pos = repeat(pos, '... p -> b (...) p', b = batch)
-        x = rearrange(x, 'b ... d -> b (...) d')
+        x, packed_shape = pack([x], 'b * d')

         x = self.dropout(x)

-        x = self.transformer(x, pos)
-
-        x = reduce(x, 'b n d -> b d', 'mean')
-
-        x = self.to_latent(x)
-        return self.mlp_head(x)
+        embed = self.transformer(x, pos)
+
+        # return the embed with reconstituted patch shape
+
+        if return_embed:
+            embed, = unpack(embed, packed_shape, 'b * d')
+            return embed
+
+        # pooling to logits
+
+        pooled = reduce(embed, 'b n d -> b d', 'mean')
+
+        pooled = self.to_latent(pooled)
+        return self.mlp_head(pooled)


 if __name__ == '__main__':
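
The hunk above swaps rearrange for einops pack/unpack so the flattened spatial axes can be restored later. Below is a minimal standalone sketch of that round trip; the tensor x and its sizes are made up for illustration and are not taken from the file. pack([x], 'b * d') collapses every axis matched by * into one sequence axis and records their sizes, and unpack uses that record to rebuild the original patch grid, which is what lets return_embed hand back the transformer output in its unflattened shape.

    # illustrative sketch of the pack/unpack round trip (shapes are invented for the example)
    import torch
    from einops import pack, unpack

    x = torch.randn(2, 4, 4, 8, 512)                    # (batch, *spatial_dims, dim)

    packed, packed_shape = pack([x], 'b * d')           # collapse spatial dims -> (2, 128, 512)
    print(packed.shape)                                 # torch.Size([2, 128, 512])

    # ... attention would run on the flattened sequence here ...

    restored, = unpack(packed, packed_shape, 'b * d')   # back to (2, 4, 4, 8, 512)
    print(restored.shape)                               # torch.Size([2, 4, 4, 8, 512])
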
@@ -288,5 +300,7 @@ def forward(self, x):
     )

     data = torch.randn(2, 3, 4, 8, 16, 32, 64)
-
+
     logits = model(data)
+
+    embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
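
To make the new call signature concrete, here is a self-contained toy that mirrors the control flow added in this commit. ToyNDModel, its Identity stand-in for the rotary transformer, and the dim/num_classes values are all invented for illustration; this is not the actual ViT class defined in vit_nd_rotary.py.

    # toy module mirroring the return_embed pattern above, on already patch-embedded tokens
    import torch
    from torch import nn
    from einops import reduce, pack, unpack

    class ToyNDModel(nn.Module):
        def __init__(self, dim = 512, num_classes = 10):
            super().__init__()
            self.transformer = nn.Identity()          # stand-in for the rotary transformer
            self.to_latent = nn.Identity()
            self.mlp_head = nn.Linear(dim, num_classes)

        def forward(self, x, return_embed = False):
            # x: (batch, *spatial_dims, dim)
            x, packed_shape = pack([x], 'b * d')      # flatten spatial dims to one sequence axis
            embed = self.transformer(x)

            if return_embed:
                embed, = unpack(embed, packed_shape, 'b * d')   # restore the patch grid shape
                return embed

            pooled = reduce(embed, 'b n d -> b d', 'mean')      # mean-pool, then classify
            return self.mlp_head(self.to_latent(pooled))

    tokens = torch.randn(2, 4, 4, 8, 512)
    logits = ToyNDModel()(tokens)                      # (2, 10)
    embed = ToyNDModel()(tokens, return_embed = True)  # (2, 4, 4, 8, 512)
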
