 # --------------------------------------------------------'
 
 import math
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn
-from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    SwiGLU,
+    LayerNorm,
+    DropPath,
+    calculate_drop_path_rates,
+    trunc_normal_,
+    use_fused_attn,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    resize_rel_pos_bias_table,
+    ndgrid,
+)
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 __all__ = ['Beit']
 
 
-def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
+def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor:
     """Generate relative position index for window-based attention.
 
     Creates a lookup table for relative position indices between all pairs of positions
@@ -74,14 +86,17 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     # cls to token & token 2 cls & cls to cls
     # get pair-wise relative position index for each token inside the window
     window_area = window_size[0] * window_size[1]
-    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(
+        torch.arange(window_size[0], device=device, dtype=torch.long),
+        torch.arange(window_size[1], device=device, dtype=torch.long),
+    ))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
     relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
     relative_coords[:, :, 1] += window_size[1] - 1
     relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-    relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
+    relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype)
     relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
     relative_position_index[0, 0:] = num_relative_distance - 3
     relative_position_index[0:, 0] = num_relative_distance - 2
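
The new `device` argument threads through both `torch.arange` calls and the `torch.zeros` allocation, so the index table can be built directly on a target device instead of always on CPU. A minimal sketch of how that might be exercised, assuming the helper stays importable from `timm.models.beit` as shown in this file:

```python
import torch
from timm.models.beit import gen_relative_position_index  # assumption: module-level helper as above

window_size = (14, 14)
idx = gen_relative_position_index(window_size)     # default behaviour, built on CPU
assert idx.shape == (14 * 14 + 1, 14 * 14 + 1)     # +1 row/col for the cls token
assert idx.dtype == torch.long                     # arange(..., dtype=torch.long) keeps it integral

if torch.cuda.is_available():
    idx_gpu = gen_relative_position_index(window_size, device='cuda')
    assert idx_gpu.device.type == 'cuda'           # no post-hoc .to() needed
```
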
@@ -107,6 +122,8 @@ def __init__(
             proj_drop: float = 0.,
             window_size: Optional[Tuple[int, int]] = None,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """Initialize attention module.
 
@@ -120,6 +137,7 @@ def __init__(
             window_size: Window size for relative position bias. If None, no relative position bias.
             attn_head_dim: Dimension per attention head. If None, uses dim // num_heads.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -130,11 +148,11 @@ def __init__(
         self.fused_attn = use_fused_attn()
         self.qkv_bias_separate = qkv_bias_separate
 
-        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd)
         if qkv_bias:
-            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
-            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim, **dd), persistent=False)
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
         else:
             self.q_bias = None
             self.k_bias = None
@@ -144,15 +162,19 @@ def __init__(
             self.window_size = window_size
             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
             self.relative_position_bias_table = nn.Parameter(
-                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-            self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
+                torch.zeros(self.num_relative_distance, num_heads, **dd))  # 2*Wh-1 * 2*Ww-1, nH
+            self.register_buffer(
+                "relative_position_index",
+                gen_relative_position_index(window_size, device=device),
+                persistent=False,
+            )
         else:
             self.window_size = None
             self.relative_position_bias_table = None
             self.relative_position_index = None
 
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj = nn.Linear(all_head_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def _get_rel_pos_bias(self) -> torch.Tensor:
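
Every tensor that `Attention.__init__` allocates (the fused qkv projection, the q/v bias parameters with the non-persistent `k_bias` buffer, the relative position bias table, and the index buffer) now receives the same `dd` factory kwargs. A rough construction sketch, under the assumption that the `Attention` class remains importable from `timm.models.beit`:

```python
import torch
from timm.models.beit import Attention  # assumption: module-level class as in this file

attn = Attention(
    dim=768,
    num_heads=12,
    qkv_bias=True,
    window_size=(14, 14),
    dtype=torch.bfloat16,  # forwarded through dd to every parameter and float buffer created here
)
assert attn.qkv.weight.dtype == torch.bfloat16
assert attn.relative_position_bias_table.dtype == torch.bfloat16
assert attn.relative_position_index.dtype == torch.long  # index buffer stays integral; only device is forwarded
```
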
@@ -245,10 +267,12 @@ def __init__(
             attn_drop: float = 0.,
             drop_path: float = 0.,
             init_values: Optional[float] = None,
-            act_layer: Callable = nn.GELU,
-            norm_layer: Callable = LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = LayerNorm,
             window_size: Optional[Tuple[int, int]] = None,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """Initialize transformer block.
 
@@ -268,8 +292,9 @@ def __init__(
             window_size: Window size for relative position bias in attention.
             attn_head_dim: Dimension per attention head.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = Attention(
             dim,
             num_heads=num_heads,
@@ -278,17 +303,19 @@ def __init__(
             proj_drop=proj_drop,
             window_size=window_size,
             attn_head_dim=attn_head_dim,
+            **dd,
         )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         if swiglu_mlp:
             self.mlp = SwiGLU(
                 in_features=dim,
                 hidden_features=int(dim * mlp_ratio),
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
         else:
             self.mlp = Mlp(
@@ -297,12 +324,13 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         if init_values:
-            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
         else:
             self.gamma_1, self.gamma_2 = None, None
 
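
`Block` forwards `dd` to `norm_layer`, the MLP, the nested `Attention`, and the optional layer-scale parameters, so a custom `norm_layer` passed in must accept `device`/`dtype` keyword arguments (as `nn.LayerNorm` and `timm.layers.LayerNorm` do). A hedged sketch of direct construction on this branch:

```python
import torch
import torch.nn as nn
from timm.models.beit import Block  # assumption: module-level class as in this file

blk = Block(
    dim=192,
    num_heads=3,
    init_values=1e-5,          # enables the gamma_1/gamma_2 layer-scale parameters
    norm_layer=nn.LayerNorm,   # norm_layer must accept device/dtype, as nn.LayerNorm does
    dtype=torch.bfloat16,
)
assert blk.gamma_1.dtype == torch.bfloat16
assert blk.norm1.weight.dtype == torch.bfloat16
```
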
@@ -332,18 +360,19 @@ class RelativePositionBias(nn.Module):
     within a window, including special handling for cls token.
     """
 
-    def __init__(self, window_size: Tuple[int, int], num_heads: int):
+    def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None):
         """Initialize relative position bias module.
 
         Args:
             window_size: Height and width of the attention window.
             num_heads: Number of attention heads.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.window_size = window_size
         self.window_area = window_size[0] * window_size[1]
         num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
-        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
+        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads, **dd))
         # trunc_normal_(self.relative_position_bias_table, std=.02)
         self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
 
@@ -385,12 +414,14 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            norm_layer: Callable = LayerNorm,
+            norm_layer: Type[nn.Module] = LayerNorm,
             init_values: Optional[float] = None,
             use_abs_pos_emb: bool = True,
             use_rel_pos_bias: bool = False,
             use_shared_rel_pos_bias: bool = False,
             head_init_scale: float = 0.001,
+            device=None,
+            dtype=None,
     ):
         """Initialize BEiT model.
 
@@ -419,6 +450,7 @@ def __init__(
             use_shared_rel_pos_bias: If True, share relative position bias across layers.
             head_init_scale: Scale factor for head initialization.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -431,19 +463,21 @@ def __init__(
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
 
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
 
         if use_shared_rel_pos_bias:
             self.rel_pos_bias = RelativePositionBias(
                 window_size=self.patch_embed.grid_size,
                 num_heads=num_heads,
+                **dd,
             )
         else:
             self.rel_pos_bias = None
@@ -463,16 +497,17 @@ def __init__(
                 norm_layer=norm_layer,
                 init_values=init_values,
                 window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
+                **dd,
             )
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
 
         use_fc_norm = self.global_pool == 'avg'
-        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd)
+        self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
 
         self.apply(self._init_weights)
         if self.pos_embed is not None:
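
Taken together, the `dd` kwargs propagate from `Beit.__init__` down through the patch embed, blocks, norms, and head, which is mainly useful for deferred or device-direct initialization. A sketch, under the assumption that this timm branch supports meta-device construction end to end (weight init on meta tensors is shape-only):

```python
import torch
from timm.models.beit import Beit  # assumption: direct construction, bypassing create_model

# Build the whole model on the meta device: parameters and buffers get their
# shapes and dtypes without allocating real memory, which is handy for planning
# sharded or low-memory loading before materializing weights elsewhere.
model = Beit(
    img_size=224,
    patch_size=16,
    embed_dim=192,
    depth=2,
    num_heads=3,
    device='meta',
)
assert all(p.device.type == 'meta' for p in model.parameters())
```
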