 # SPDX-License-Identifier: Apache-2.0

 from functools import partial
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -40,7 +40,17 @@ class PatchEmbed(nn.Module):
4040 """ Image to Patch Embedding
4141 """
4242
43- def __init__ (self , img_size = 224 , patch_size = 16 , in_chans = 3 , embed_dim = 768 , multi_conv = False ):
43+ def __init__ (
44+ self ,
45+ img_size : Union [int , Tuple [int , int ]] = 224 ,
46+ patch_size : int = 16 ,
47+ in_chans : int = 3 ,
48+ embed_dim : int = 768 ,
49+ multi_conv : bool = False ,
50+ device = None ,
51+ dtype = None ,
52+ ):
53+ dd = {'device' : device , 'dtype' : dtype }
4454 super ().__init__ ()
4555 img_size = to_2tuple (img_size )
4656 patch_size = to_2tuple (patch_size )
@@ -51,22 +61,22 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi
         if multi_conv:
             if patch_size[0] == 12:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd),
                 )
             elif patch_size[0] == 16:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd),
                 )
         else:
-            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd)

     def forward(self, x):
         B, C, H, W = x.shape
@@ -82,23 +92,26 @@ def forward(self, x):
 class CrossAttention(nn.Module):
     def __init__(
             self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            attn_drop=0.,
-            proj_drop=0.,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = head_dim ** -0.5

-        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
+        self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
@@ -124,24 +137,28 @@ class CrossAttentionBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = CrossAttention(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -155,20 +172,22 @@ class MultiScaleBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            patches,
-            depth,
-            num_heads,
-            mlp_ratio,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: Tuple[int, ...],
+            patches: Tuple[int, ...],
+            depth: Tuple[int, ...],
+            num_heads: Tuple[int, ...],
+            mlp_ratio: Tuple[float, ...],
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: Union[List[float], float] = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-
         num_branches = len(dim)
         self.num_branches = num_branches
         # different branch could have different embedding size, the first one is the base
@@ -185,6 +204,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[i],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             if len(tmp) != 0:
                 self.blocks.append(nn.Sequential(*tmp))
@@ -197,7 +217,7 @@ def __init__(
             if dim[d] == dim[(d + 1) % num_branches] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
+                tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)]
             self.projs.append(nn.Sequential(*tmp))

         self.fusion = nn.ModuleList()
@@ -215,6 +235,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[-1],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             else:
                 tmp = []
@@ -228,6 +249,7 @@ def __init__(
                         attn_drop=attn_drop,
                         drop_path=drop_path[-1],
                         norm_layer=norm_layer,
+                        **dd,
                     ))
                 self.fusion.append(nn.Sequential(*tmp))

@@ -236,8 +258,8 @@ def __init__(
             if dim[(d + 1) % num_branches] == dim[d] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
-                       nn.Linear(dim[(d + 1) % num_branches], dim[d])]
+                tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(),
+                       nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)]
             self.revert_projs.append(nn.Sequential(*tmp))

     def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -293,27 +315,30 @@ class CrossVit(nn.Module):

     def __init__(
             self,
-            img_size=224,
-            img_scale=(1.0, 1.0),
-            patch_size=(8, 16),
-            in_chans=3,
-            num_classes=1000,
-            embed_dim=(192, 384),
-            depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
-            num_heads=(6, 12),
-            mlp_ratio=(2., 2., 4.),
-            multi_conv=False,
-            crop_scale=False,
-            qkv_bias=True,
-            drop_rate=0.,
-            pos_drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            global_pool='token',
+            img_size: int = 224,
+            img_scale: Tuple[float, ...] = (1.0, 1.0),
+            patch_size: Tuple[int, ...] = (8, 16),
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            embed_dim: Tuple[int, ...] = (192, 384),
+            depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
+            num_heads: Tuple[int, ...] = (6, 12),
+            mlp_ratio: Tuple[float, ...] = (2., 2., 4.),
+            multi_conv: bool = False,
+            crop_scale: bool = False,
+            qkv_bias: bool = True,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+            global_pool: str = 'token',
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('token', 'avg')

         self.num_classes = num_classes
@@ -330,8 +355,8 @@ def __init__(

         # hard-coded for torch jit script
         for i in range(self.num_branches):
-            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
-            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd)))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd)))

         for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
             self.patch_embed.append(
@@ -341,6 +366,7 @@ def __init__(
                     in_chans=in_chans,
                     embed_dim=d,
                     multi_conv=multi_conv,
+                    **dd,
                 ))

         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -363,14 +389,15 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr_,
                 norm_layer=norm_layer,
+                **dd,
             )
             dpr_ptr += curr_depth
             self.blocks.append(blk)

-        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
+        self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)])
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.ModuleList([
-            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)])

         for i in range(self.num_branches):
@@ -418,8 +445,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             assert global_pool in ('token', 'avg')
             self.global_pool = global_pool
+        device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None
+        dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None
+        dd = {'device': device, 'dtype': dtype}
         self.head = nn.ModuleList([
-            nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)
         ])

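The diff above applies PyTorch's factory-kwargs convention throughout crossvit.py: each constructor accepts optional device and dtype arguments, bundles them into a dd dict, and forwards **dd to every parameter-creating call (nn.Conv2d, nn.Linear, the norm_layer, torch.zeros for pos_embed/cls_token), so the whole model can be instantiated directly on a target device or in a target dtype, for example on the 'meta' device for deferred initialization. Below is a minimal sketch of the same pattern, separate from this commit; the module and names (TinyHead, in_features) are illustrative only.

import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Toy module showing the device/dtype factory-kwargs pattern used in the diff."""

    def __init__(self, in_features: int = 192, num_classes: int = 1000, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Forward the factory kwargs so weights are allocated on the requested device/dtype.
        self.norm = nn.LayerNorm(in_features, **dd)
        self.fc = nn.Linear(in_features, num_classes, **dd)

    def forward(self, x):
        return self.fc(self.norm(x))


# Usage: build on the meta device (no real storage allocated) or in half precision.
meta_head = TinyHead(device='meta')
half_head = TinyHead(dtype=torch.float16)
print(meta_head.fc.weight.device)   # meta
print(half_head.fc.weight.dtype)    # torch.float16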