@@ -178,7 +178,8 @@ def __init__(
         channels = 3,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        num_register_tokens = 0
     ):
         super().__init__()
         self.dim = dim
@@ -200,8 +201,8 @@ def __init__(
             nn.LayerNorm(dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
+        self.cls_token = nn.Parameter(torch.randn(dim))
         self.dropout = nn.Dropout(emb_dropout)
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -211,13 +212,19 @@ def __init__(
 
         self.mlp_head = nn.Linear(dim, num_classes)
 
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
     def forward(self, img, return_hiddens = False):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
-        x = cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:n]
+
+        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
+
+        x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
+
         x = self.dropout(x)
 
         x, hiddens = self.transformer(x, return_hiddens = True)
@@ -227,7 +234,9 @@ def forward(self, img, return_hiddens = False):
         if return_hiddens:
             return x, stack(hiddens)
 
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
 
         x = self.to_latent(x)
         return self.mlp_head(x)
@@ -251,6 +260,7 @@ def __init__(
         num_views = None,
         num_tasks = None,
         dim_extra_token = None,
+        num_register_tokens = 4,
         action_chunk_len = 7,
         time_seq_len = 1,
         dropout = 0.,
@@ -295,6 +305,10 @@ def __init__(
         if self.has_tasks:
             self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
 
+        # register tokens from Darcet et al.
+
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -407,6 +421,12 @@ def forward(
 
         action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
 
+        # register tokens
+
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+
+        action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+
         # cross attention
 
         hiddens = [action_tokens]
@@ -425,6 +445,10 @@ def forward(
 
             hiddens.append(action_tokens)
 
+        # unpack registers
+
+        _, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
+
         # maybe unpack extra
 
         if has_extra:
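
A quick standalone sketch of the einops pack / unpack round trip that both classes lean on (shapes and variable names below are illustrative, not taken from the model): pack flattens each tensor's middle axes onto the '*' position and records their sizes, and unpack splits the packed tensor back out in the same order the inputs were packed.

import torch
from einops import pack, unpack, repeat

b, n, d, r = 2, 16, 64, 4                                     # batch, patches, dim, registers (made up)

registers = repeat(torch.randn(r, d), 'n d -> b n d', b = b)  # (b, r, d)
cls_token = repeat(torch.randn(d), 'd -> b d', b = b)         # (b, d), packs as a single token
patches   = torch.randn(b, n, d)                              # (b, n, d)

# concatenate along the token axis; packed_shape remembers each tensor's contribution
tokens, packed_shape = pack((registers, cls_token, patches), 'b * d')  # (b, r + 1 + n, d)

# unpack returns the pieces in the same order they were packed
registers, cls_token, patches = unpack(tokens, packed_shape, 'b * d')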