 from torch.nn import Module, ModuleList
 import torch.nn.functional as F

-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 # helpers
@@ -245,7 +245,11 @@ def __init__(
         self.to_latent = nn.Identity()
         self.mlp_head = nn.Linear(dim, num_classes)

-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_embed = False
+    ):
         x = self.to_patch_embedding(x)  # (b, *spatial_dims, patch_dim)

         batch, *spatial_dims, _, device = *x.shape, x.device
@@ -259,16 +263,24 @@ def forward(self, x):
         # flatten spatial dimensions for attention with nd rotary

         pos = repeat(pos, '... p -> b (...) p', b = batch)
-        x = rearrange(x, 'b ... d -> b (...) d')
+        x, packed_shape = pack([x], 'b * d')

         x = self.dropout(x)

-        x = self.transformer(x, pos)
-
-        x = reduce(x, 'b n d -> b d', 'mean')
-
-        x = self.to_latent(x)
-        return self.mlp_head(x)
+        embed = self.transformer(x, pos)
+
+        # return the embed with reconstituted patch shape
+
+        if return_embed:
+            embed, = unpack(embed, packed_shape, 'b * d')
+            return embed
+
+        # pooling to logits
+
+        pooled = reduce(embed, 'b n d -> b d', 'mean')
+
+        pooled = self.to_latent(pooled)
+        return self.mlp_head(pooled)


 if __name__ == '__main__':
@@ -288,5 +300,7 @@ def forward(self, x):
     )

     data = torch.randn(2, 3, 4, 8, 16, 32, 64)
-
+
     logits = model(data)
+
+    embed = model(data, return_embed = True)  # (2, 2, 4, 4, 8, 8, 512)
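
For reference, below is a minimal standalone sketch (not part of the patch) of the einops pack/unpack round trip the new forward path relies on: pack flattens the n-d patch grid into a single token axis for attention, and unpack restores the original spatial shape when return_embed = True. The shapes are taken from the example comment above and are illustrative only.

import torch
from einops import pack, unpack

dim = 512
x = torch.randn(2, 2, 4, 4, 8, 8, dim)             # (batch, *spatial_dims, dim)

# pack collapses every dim matched by '*' into one axis and records the packed shape
tokens, packed_shape = pack([x], 'b * d')          # -> (2, 2048, 512)
assert tokens.shape == (2, 2 * 4 * 4 * 8 * 8, dim)

# ... attention would run over the flattened token axis here ...

# unpack restores the recorded spatial dims, as done when return_embed = True
restored, = unpack(tokens, packed_shape, 'b * d')  # -> (2, 2, 4, 4, 8, 8, 512)
assert restored.shape == x.shape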