 # --------------------------------------------------------
 import logging
 import math
-from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union, Type

 import torch
 import torch.nn as nn
@@ -77,7 +77,7 @@ def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int,
     return x


-def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
+def get_relative_position_index(win_h: int, win_w: int, device=None) -> torch.Tensor:
     """Get pair-wise relative position index for each token inside the window.

     Args:
@@ -88,7 +88,10 @@ def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
         Relative position index tensor.
     """
     # get pair-wise relative position index for each token inside the window
-    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(
+        torch.arange(win_h, device=device, dtype=torch.long),
+        torch.arange(win_w, device=device, dtype=torch.long),
+    ))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
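For reference, a standalone sketch of what this helper computes, assuming `ndgrid` behaves like `torch.meshgrid(..., indexing='ij')` and including the offset-and-flatten steps that follow in the unmodified remainder of the function. The new `device=` argument only controls where the index buffer is allocated; the values are unchanged:

```python
import torch

def rel_pos_index(win_h: int, win_w: int, device=None) -> torch.Tensor:
    # Pair-wise relative position index for a win_h x win_w window.
    coords = torch.stack(torch.meshgrid(
        torch.arange(win_h, device=device, dtype=torch.long),
        torch.arange(win_w, device=device, dtype=torch.long),
        indexing='ij',
    ))                                          # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)   # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    rel = rel.permute(1, 2, 0).contiguous()     # Wh*Ww, Wh*Ww, 2
    rel[:, :, 0] += win_h - 1                   # shift offsets to start from 0
    rel[:, :, 1] += win_w - 1
    rel[:, :, 0] *= 2 * win_w - 1
    return rel.sum(-1)                          # Wh*Ww, Wh*Ww

print(rel_pos_index(2, 2))
# 4x4 indices into a (2*2-1)*(2*2-1) = 9-row relative position bias table
```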
@@ -114,6 +117,8 @@ def __init__(
             qkv_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -125,6 +130,7 @@
             attn_drop: Dropout ratio of attention weight.
             proj_drop: Dropout ratio of output.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.window_size = to_2tuple(window_size)  # Wh, Ww
@@ -137,14 +143,19 @@
         self.fused_attn = use_fused_attn(experimental=True)  # NOTE not tested for prime-time yet

         # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
-        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads, **dd))

         # get pair-wise relative position index for each token inside the window
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+        self.register_buffer(
+            "relative_position_index",
+            get_relative_position_index(win_h, win_w, device=device),
+            persistent=False,
+        )

-        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(attn_dim, dim)
+        self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

         trunc_normal_(self.relative_position_bias_table, std=.02)
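The `dd = {'device': device, 'dtype': dtype}` dict is the factory-kwargs idiom PyTorch's own modules use: every parameter- or buffer-creating call receives `**dd`, so weights are created on the requested device and dtype rather than on CPU and moved afterwards. A toy sketch of the pattern (the `TinyBlock` class is hypothetical, not part of this file):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Both the parameterized layer and the index buffer honor device/dtype.
        self.proj = nn.Linear(dim, dim, **dd)
        self.register_buffer('ids', torch.arange(dim, device=device), persistent=False)

# Deferred allocation: nothing is materialized on the meta device.
blk = TinyBlock(64, device='meta', dtype=torch.bfloat16)
print(blk.proj.weight.device, blk.proj.weight.dtype)  # meta torch.bfloat16
```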
@@ -169,7 +180,11 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None:
                 new_window_size=self.window_size,
                 new_bias_shape=new_bias_shape,
             ))
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+        self.register_buffer(
+            "relative_position_index",
+            get_relative_position_index(win_h, win_w, device=self.relative_position_bias_table.device),
+            persistent=False,
+        )

     def _get_rel_pos_bias(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[
@@ -241,8 +256,10 @@ def __init__(
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: Callable = nn.GELU,
-            norm_layer: Callable = nn.LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -261,6 +278,7 @@
             act_layer: Activation layer.
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -271,7 +289,7 @@
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio

-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = WindowAttention(
             dim,
             num_heads=num_heads,
@@ -280,25 +298,32 @@
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         self.mlp = Mlp(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.register_buffer(
             "attn_mask",
-            None if self.dynamic_mask else self.get_attn_mask(),
+            None if self.dynamic_mask else self.get_attn_mask(**dd),
             persistent=False,
         )

-    def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
+    def get_attn_mask(
+            self,
+            x: Optional[torch.Tensor] = None,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> Optional[torch.Tensor]:
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             if x is not None:
@@ -307,8 +332,8 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
                 dtype = x.dtype
             else:
                 H, W = self.input_resolution
-                device = None
-                dtype = None
+                device = device
+                dtype = dtype
             H = math.ceil(H / self.window_size[0]) * self.window_size[0]
             W = math.ceil(W / self.window_size[1]) * self.window_size[1]
             img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device)  # 1 H W 1
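With the extended signature, the precomputed SW-MSA mask can be placed directly on the target device and dtype when no reference activation tensor is available, instead of falling back to CPU/float32 defaults. A hedged usage sketch against the patched block (CPU is used here only so the snippet runs anywhere; `'cuda'` or `'meta'` are the intended targets):

```python
import torch
from timm.models.swin_transformer import SwinTransformerBlock

# Assumes the patched constructor from this diff (device/dtype kwargs).
blk = SwinTransformerBlock(
    dim=96,
    input_resolution=(56, 56),
    num_heads=3,
    window_size=7,
    shift_size=3,            # non-zero shift -> a mask is actually built
    device='cpu',
    dtype=torch.bfloat16,
)
mask = blk.get_attn_mask(device=torch.device('cpu'), dtype=torch.bfloat16)
print(mask.shape, mask.dtype)  # (num_windows, 49, 49), torch.bfloat16
```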
@@ -372,9 +397,11 @@ def set_input_size(
         self.window_size, self.shift_size = self._calc_window_shift(window_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.attn.set_window_size(self.window_size)
+        device = self.attn_mask.device if self.attn_mask is not None else None
+        dtype = self.attn_mask.dtype if self.attn_mask is not None else None
         self.register_buffer(
             "attn_mask",
-            None if self.dynamic_mask else self.get_attn_mask(),
+            None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype),
             persistent=False,
         )

@@ -444,19 +471,22 @@ def __init__(
             self,
             dim: int,
             out_dim: Optional[int] = None,
-            norm_layer: Callable = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
             dim: Number of input channels.
             out_dim: Number of output channels (or 2 * dim if None)
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.out_dim = out_dim or 2 * dim
-        self.norm = norm_layer(4 * dim)
-        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
+        self.norm = norm_layer(4 * dim, **dd)
+        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
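As a reminder of why the layers above operate on `4 * dim`: patch merging concatenates each 2x2 neighborhood of patches along the channel axis before the norm and linear reduction. A standalone sketch of that reshape (mirroring the upstream forward, written here for illustration only):

```python
import torch

x = torch.randn(1, 8, 8, 96)                  # B, H, W, C (NHWC feature map)
B, H, W, C = x.shape
x = x.reshape(B, H // 2, 2, W // 2, 2, C)     # split H and W into 2x2 groups
x = x.permute(0, 1, 3, 4, 2, 5).flatten(3)    # B, H/2, W/2, 4*C
print(x.shape)                                # torch.Size([1, 4, 4, 384])
```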
@@ -502,7 +532,9 @@ def __init__(
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: Union[List[float], float] = 0.,
-            norm_layer: Callable = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -521,6 +553,7 @@
             drop_path: Stochastic depth rate.
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -536,6 +569,7 @@
                 dim=dim,
                 out_dim=out_dim,
                 norm_layer=norm_layer,
+                **dd,
             )
         else:
             assert dim == out_dim
@@ -558,6 +592,7 @@
                 attn_drop=attn_drop,
                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                 norm_layer=norm_layer,
+                **dd,
             )
             for i in range(depth)])

@@ -631,9 +666,11 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
-            embed_layer: Callable = PatchEmbed,
-            norm_layer: Union[str, Callable] = nn.LayerNorm,
+            embed_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Union[str, Type[nn.Module]] = nn.LayerNorm,
             weight_init: str = '',
+            device=None,
+            dtype=None,
             **kwargs,
     ):
         """
@@ -656,6 +693,7 @@
             norm_layer (nn.Module): Normalization layer.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'avg')
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -678,6 +716,7 @@
             norm_layer=norm_layer,
             strict_img_size=strict_img_size,
             output_fmt='NHWC',
+            **dd,
         )
         patch_grid = self.patch_embed.grid_size

@@ -715,20 +754,22 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
+                **dd,
             )]
             in_dim = out_dim
             if i > 0:
                 scale *= 2
             self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
         self.layers = nn.Sequential(*layers)

-        self.norm = norm_layer(self.num_features)
+        self.norm = norm_layer(self.num_features, **dd)
         self.head = ClassifierHead(
             self.num_features,
             num_classes,
             pool_type=global_pool,
             drop_rate=drop_rate,
             input_fmt=self.output_fmt,
+            **dd,
         )
         if weight_init != 'skip':
             self.init_weights(weight_init)
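Taken together, the constructor changes let the full model be instantiated directly on a target device and dtype. A minimal sketch, assuming the patched `SwinTransformer` from this diff; the hyper-parameters shown are the standard Swin-T values, used here only for illustration:

```python
import torch
from timm.models.swin_transformer import SwinTransformer

# Build on the meta device: parameters and buffers get their shapes and dtypes
# but no storage, which helps with sharded or low-memory checkpoint loading.
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    device='meta',
    dtype=torch.float16,
)

# Materialize (uninitialized) weights on a real device before loading a state dict.
model = model.to_empty(device='cpu')
```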