44"""
55import re
66from collections import OrderedDict
7- from typing import Any , Dict , Optional , Tuple , Union
7+ from typing import Any , Dict , Optional , Tuple , Type , Union
88
99import torch
1010import torch .nn as nn
@@ -31,9 +31,11 @@ def __init__(
             num_input_features: int,
             growth_rate: int,
             bn_size: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseLayer.
 
@@ -45,13 +47,14 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseLayer, self).__init__()
-        self.add_module('norm1', norm_layer(num_input_features)),
+        self.add_module('norm1', norm_layer(num_input_features, **dd)),
         self.add_module('conv1', nn.Conv2d(
-            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm2', norm_layer(bn_size * growth_rate)),
+            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, **dd)),
+        self.add_module('norm2', norm_layer(bn_size * growth_rate, **dd)),
         self.add_module('conv2', nn.Conv2d(
-            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
+            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, **dd)),
         self.drop_rate = float(drop_rate)
         self.grad_checkpointing = grad_checkpointing
 
@@ -129,9 +132,11 @@ def __init__(
             num_input_features: int,
             bn_size: int,
             growth_rate: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseBlock.
 
@@ -144,6 +149,7 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -153,6 +159,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
                 grad_checkpointing=grad_checkpointing,
+                **dd,
             )
             self.add_module('denselayer%d' % (i + 1), layer)
 
@@ -182,8 +189,10 @@ def __init__(
             self,
             num_input_features: int,
             num_output_features: int,
-            norm_layer: type = BatchNormAct2d,
-            aa_layer: Optional[type] = None,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseTransition.
 
@@ -193,12 +202,13 @@ def __init__(
             norm_layer: Normalization layer class.
             aa_layer: Anti-aliasing layer class.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseTransition, self).__init__()
-        self.add_module('norm', norm_layer(num_input_features))
+        self.add_module('norm', norm_layer(num_input_features, **dd))
         self.add_module('conv', nn.Conv2d(
-            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
+            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, **dd))
         if aa_layer is not None:
-            self.add_module('pool', aa_layer(num_output_features, stride=2))
+            self.add_module('pool', aa_layer(num_output_features, stride=2, **dd))
         else:
             self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
 
@@ -231,11 +241,13 @@ def __init__(
             stem_type: str = '',
             act_layer: str = 'relu',
             norm_layer: str = 'batchnorm2d',
-            aa_layer: Optional[type] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
             drop_rate: float = 0.,
             proj_drop_rate: float = 0.,
             memory_efficient: bool = False,
             aa_stem_only: bool = True,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseNet.
 
@@ -255,6 +267,7 @@ def __init__(
             memory_efficient: If True, uses checkpointing for memory efficiency.
             aa_stem_only: Apply anti-aliasing only to stem.
         """
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         super(DenseNet, self).__init__()
         norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -267,25 +280,25 @@ def __init__(
         else:
             stem_pool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=num_init_features, stride=2)])
+                aa_layer(channels=num_init_features, stride=2, **dd)])
         if deep_stem:
             stem_chs_1 = stem_chs_2 = growth_rate
             if 'tiered' in stem_type:
                 stem_chs_1 = 3 * (growth_rate // 4)
                 stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
-                ('norm0', norm_layer(stem_chs_1)),
-                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
-                ('norm1', norm_layer(stem_chs_2)),
-                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
-                ('norm2', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False, **dd)),
+                ('norm0', norm_layer(stem_chs_1, **dd)),
+                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm1', norm_layer(stem_chs_2, **dd)),
+                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm2', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         else:
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
-                ('norm0', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
+                ('norm0', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         self.feature_info = [
@@ -303,6 +316,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=proj_drop_rate,
                 grad_checkpointing=memory_efficient,
+                **dd,
             )
             module_name = f'denseblock{(i + 1)}'
             self.features.add_module(module_name, block)
@@ -317,12 +331,13 @@ def __init__(
                     num_output_features=num_features // 2,
                     norm_layer=norm_layer,
                     aa_layer=transition_aa_layer,
+                    **dd,
                 )
                 self.features.add_module(f'transition{i + 1}', trans)
                 num_features = num_features // 2
 
         # Final batch norm
-        self.features.add_module('norm5', norm_layer(num_features))
+        self.features.add_module('norm5', norm_layer(num_features, **dd))
 
         self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
         self.num_features = self.head_hidden_size = num_features
@@ -332,6 +347,7 @@ def __init__(
             self.num_features,
             self.num_classes,
             pool_type=global_pool,
+            **dd,
         )
         self.global_pool = global_pool
         self.head_drop = nn.Dropout(drop_rate)
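
Below is a minimal usage sketch (illustrative, not part of the diff) of what the new device/dtype factory kwargs enable: modules can be constructed directly on a target device or in a target dtype instead of being moved afterwards with .to(). It assumes this file is timm's densenet.py and that the default norm layer (BatchNormAct2d) accepts the same kwargs, as the change implies.

# Usage sketch for the new factory kwargs (hypothetical example, not from the PR).
import torch
from timm.models.densenet import DenseLayer

# Create parameters in float16 from the start, no post-hoc conversion needed.
layer_fp16 = DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, dtype=torch.float16)
assert layer_fp16.conv1.weight.dtype == torch.float16

# Create on the meta device for deferred materialization (no memory allocated).
layer_meta = DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, device='meta')
assert layer_meta.conv1.weight.is_meta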