@@ -203,8 +203,8 @@ def __init__(
             unique_sizes: List[Tuple[int, int]],
             batch_size: int,
             seq_len: int,
-            dtype: torch.dtype,
             device: torch.device,
+            dtype: torch.dtype,
     ):
         self.rope = rope_module
         self.size_to_indices = size_to_indices
@@ -362,6 +362,8 @@ def __init__(
             norm_layer: Optional[Type[nn.Module]] = None,
             pos_drop_rate: float = 0.,
             enable_patch_interpolator: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize NaFlexEmbeds module.

@@ -385,6 +387,7 @@ def __init__(
             pos_drop_rate: Dropout rate for position embeddings.
             enable_patch_interpolator: Enable dynamic patch size support.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.has_class_token = class_token
         self.num_reg_tokens = reg_tokens
@@ -402,8 +405,8 @@ def __init__(
         self.num_prefix_tokens += reg_tokens

         # Create class and register tokens
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None

         # Calculate grid size and number of patches
         self.default_img_size: Optional[Tuple[int, int]] = None
@@ -425,15 +428,20 @@ def __init__(
425428 "`norm_layer` must be given when input_norm_layer=True"
426429 input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None )
427430 self .norm_input = input_norm_layer (patch_dim ) if input_norm_layer else None
428- self .proj = nn .Linear (patch_dim , embed_dim , bias = proj_bias )
431+ self .proj = nn .Linear (patch_dim , embed_dim , bias = proj_bias , ** dd )
429432 self .flatten = False
430433 self .is_linear = True
431434 else :
432435 # Default to convolutional patch embedding for image inputs
433436 assert not input_norm_layer
434437 self .norm_input = None
435438 self .proj = nn .Conv2d (
436- in_chans , embed_dim , kernel_size = patch_size , stride = patch_size , bias = proj_bias
439+ in_chans ,
440+ embed_dim ,
441+ kernel_size = patch_size ,
442+ stride = patch_size ,
443+ bias = proj_bias ,
444+ ** dd ,
437445 )
438446 self .flatten = True
439447 self .is_linear = False
@@ -470,12 +478,12 @@ def __init__(
             assert self.pos_embed_grid_size is not None
             h, w = self.pos_embed_grid_size
             self.pos_embed_type = 'factorized'
-            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
-            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
+            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim, **dd) * .02)
+            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim, **dd) * .02)
         else:
             assert self.pos_embed_grid_size is not None
             h, w = self.pos_embed_grid_size
-            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
+            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim, **dd) * .02)
             self.pos_embed_type = 'learned'

         # Dropout layer
@@ -1080,6 +1088,8 @@ def __init__(
             in_chans: int = 3,
             num_classes: int = 1000,
             img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            device=None,
+            dtype=None,
             **kwargs,
     ) -> None:
         """Initialize NaFlexVit model.
10851095 """Initialize NaFlexVit model.
@@ -1092,6 +1102,7 @@ def __init__(
             **kwargs: Additional config parameters to override cfg values.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}

         # Initialize config
         cfg = cfg or NaFlexVitCfg()
@@ -1141,8 +1152,9 @@ def __init__(
             proj_norm_layer=embed_norm_layer,
             pos_drop_rate=cfg.pos_drop_rate,
             enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False),
+            **dd,
         )
-        self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity()

         # ROPE position embeddings at model level
         self.rope: Optional[nn.Module] = None
@@ -1157,6 +1169,7 @@ def __init__(
                 temperature=cfg.rope_temperature,
                 feat_shape=None,  # Dynamic shapes for NaFlex
                 grid_indexing=cfg.rope_grid_indexing,
+                **dd,
             )
             self.rope_is_mixed = True
         elif cfg.rope_type == 'axial':
@@ -1168,6 +1181,7 @@ def __init__(
                 ref_feat_shape=cfg.rope_ref_feat_shape,
                 grid_offset=cfg.rope_grid_offset,
                 grid_indexing=cfg.rope_grid_indexing,
+                **dd,
             )
             self.rope_is_mixed = False
         else:
@@ -1200,6 +1214,7 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                **dd,
             )
             for i in range(cfg.depth)
         ])
@@ -1211,7 +1226,7 @@ def __init__(
             for i in range(cfg.depth)
         ]

-        self.norm = norm_layer(cfg.embed_dim) if cfg.final_norm and not cfg.fc_norm else nn.Identity()
+        self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity()

         # Classifier Head
         if cfg.global_pool == 'map':
@@ -1221,6 +1236,7 @@ def __init__(
                 mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
+                **dd,
             )
         else:
             self.attn_pool = None
@@ -1229,9 +1245,9 @@ def __init__(
         fc_norm = cfg.fc_norm
         if fc_norm is None:
             fc_norm = cfg.global_pool == 'avg'
-        self.fc_norm = norm_layer(cfg.embed_dim) if cfg.final_norm and fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(cfg.drop_rate)
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

         if cfg.weight_init != 'skip':
             self.init_weights(cfg.weight_init)
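
The change set above is mechanical: each constructor gains optional device/dtype arguments, bundles them into dd = {'device': device, 'dtype': dtype}, and forwards **dd to every tensor factory (torch.zeros, torch.randn) and layer (nn.Linear, nn.Conv2d, norm_layer, the transformer blocks, attention pool, and head) so parameters are created directly on the requested device and in the requested dtype. For reference, a minimal, self-contained sketch of that pattern follows; the TinyEmbed module is a hypothetical illustration, not timm code.

import torch
import torch.nn as nn


class TinyEmbed(nn.Module):
    """Illustrative module (not timm code) showing the factory-kwargs pattern."""

    def __init__(self, patch_dim: int, embed_dim: int, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        # Same idiom as the diff: thread **dd into parameter factories and layers.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
        self.proj = nn.Linear(patch_dim, embed_dim, **dd)


# Materialize weights directly in bfloat16, or on the meta device for deferred init.
m = TinyEmbed(768, 384, dtype=torch.bfloat16)
meta = TinyEmbed(768, 384, device='meta')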