Commit c7955eb

Add dd factory kwargs to all swin transformers and volo
1 parent 6a3342c commit c7955eb
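
The pattern this commit applies throughout is PyTorch's factory-kwargs convention: each constructor collects `device` and `dtype` into a `dd` dict and splats it into every call that allocates parameters or buffers, so tensors are created directly on the requested device in the requested dtype instead of being built in default float32 on CPU and converted afterwards. A minimal sketch of the idea (the `TinyBlock` module below is a hypothetical illustration, not part of timm):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical module showing the factory-kwargs ('dd') pattern."""

    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Every parameter/buffer-creating call receives the same device/dtype,
        # so tensors are allocated directly where (and how) the caller asked.
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)
        self.register_buffer('scale', torch.ones(dim, **dd), persistent=False)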

File tree

4 files changed: 344 additions, 172 deletions


timm/models/swin_transformer.py

Lines changed: 65 additions & 24 deletions
@@ -17,7 +17,7 @@
 # --------------------------------------------------------
 import logging
 import math
-from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union, Type
 
 import torch
 import torch.nn as nn
@@ -77,7 +77,7 @@ def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int,
     return x
 
 
-def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
+def get_relative_position_index(win_h: int, win_w: int, device=None) -> torch.Tensor:
     """Get pair-wise relative position index for each token inside the window.
 
     Args:
@@ -88,7 +88,10 @@ def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
         Relative position index tensor.
     """
     # get pair-wise relative position index for each token inside the window
-    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(
+        torch.arange(win_h, device=device, dtype=torch.long),
+        torch.arange(win_w, device=device, dtype=torch.long),
+    ))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
@@ -114,6 +117,8 @@ def __init__(
             qkv_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -125,6 +130,7 @@
             attn_drop: Dropout ratio of attention weight.
             proj_drop: Dropout ratio of output.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.window_size = to_2tuple(window_size)  # Wh, Ww
@@ -137,14 +143,19 @@
         self.fused_attn = use_fused_attn(experimental=True)  # NOTE not tested for prime-time yet
 
         # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
-        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads, **dd))
 
         # get pair-wise relative position index for each token inside the window
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+        self.register_buffer(
+            "relative_position_index",
+            get_relative_position_index(win_h, win_w, device=device),
+            persistent=False,
+        )
 
-        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(attn_dim, dim)
+        self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
         trunc_normal_(self.relative_position_bias_table, std=.02)
@@ -169,7 +180,11 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None:
                 new_window_size=self.window_size,
                 new_bias_shape=new_bias_shape,
             ))
-            self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+            self.register_buffer(
+                "relative_position_index",
+                get_relative_position_index(win_h, win_w, device=self.relative_position_bias_table.device),
+                persistent=False,
+            )
 
     def _get_rel_pos_bias(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[
@@ -241,8 +256,10 @@ def __init__(
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: Callable = nn.GELU,
-            norm_layer: Callable = nn.LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -261,6 +278,7 @@ def __init__(
             act_layer: Activation layer.
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -271,7 +289,7 @@ def __init__(
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
 
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = WindowAttention(
             dim,
             num_heads=num_heads,
@@ -280,25 +298,32 @@ def __init__(
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         self.mlp = Mlp(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.register_buffer(
             "attn_mask",
-            None if self.dynamic_mask else self.get_attn_mask(),
+            None if self.dynamic_mask else self.get_attn_mask(**dd),
             persistent=False,
         )
 
-    def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
+    def get_attn_mask(
+            self,
+            x: Optional[torch.Tensor] = None,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> Optional[torch.Tensor]:
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             if x is not None:
@@ -307,8 +332,8 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
                 dtype = x.dtype
             else:
                 H, W = self.input_resolution
-                device = None
-                dtype = None
+                device = device
+                dtype = dtype
             H = math.ceil(H / self.window_size[0]) * self.window_size[0]
             W = math.ceil(W / self.window_size[1]) * self.window_size[1]
             img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device)  # 1 H W 1
@@ -372,9 +397,11 @@ def set_input_size(
         self.window_size, self.shift_size = self._calc_window_shift(window_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.attn.set_window_size(self.window_size)
+        device = self.attn_mask.device if self.attn_mask is not None else None
+        dtype = self.attn_mask.dtype if self.attn_mask is not None else None
         self.register_buffer(
             "attn_mask",
-            None if self.dynamic_mask else self.get_attn_mask(),
+            None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype),
            persistent=False,
         )
 
@@ -444,19 +471,22 @@ def __init__(
             self,
             dim: int,
             out_dim: Optional[int] = None,
-            norm_layer: Callable = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
             dim: Number of input channels.
             out_dim: Number of output channels (or 2 * dim if None)
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.out_dim = out_dim or 2 * dim
-        self.norm = norm_layer(4 * dim)
-        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
+        self.norm = norm_layer(4 * dim, **dd)
+        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -502,7 +532,9 @@ def __init__(
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: Union[List[float], float] = 0.,
-            norm_layer: Callable = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -521,6 +553,7 @@ def __init__(
             drop_path: Stochastic depth rate.
             norm_layer: Normalization layer.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -536,6 +569,7 @@ def __init__(
                 dim=dim,
                 out_dim=out_dim,
                 norm_layer=norm_layer,
+                **dd,
             )
         else:
             assert dim == out_dim
@@ -558,6 +592,7 @@ def __init__(
                 attn_drop=attn_drop,
                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                 norm_layer=norm_layer,
+                **dd,
             )
             for i in range(depth)])
 
@@ -631,9 +666,11 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
-            embed_layer: Callable = PatchEmbed,
-            norm_layer: Union[str, Callable] = nn.LayerNorm,
+            embed_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Union[str, Type[nn.Module]] = nn.LayerNorm,
             weight_init: str = '',
+            device=None,
+            dtype=None,
             **kwargs,
     ):
         """
@@ -656,6 +693,7 @@ def __init__(
             norm_layer (nn.Module): Normalization layer.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'avg')
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -678,6 +716,7 @@ def __init__(
             norm_layer=norm_layer,
             strict_img_size=strict_img_size,
             output_fmt='NHWC',
+            **dd,
         )
         patch_grid = self.patch_embed.grid_size
 
@@ -715,20 +754,22 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
+                **dd,
             )]
             in_dim = out_dim
             if i > 0:
                 scale *= 2
             self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
         self.layers = nn.Sequential(*layers)
 
-        self.norm = norm_layer(self.num_features)
+        self.norm = norm_layer(self.num_features, **dd)
         self.head = ClassifierHead(
             self.num_features,
             num_classes,
             pool_type=global_pool,
             drop_rate=drop_rate,
             input_fmt=self.output_fmt,
+            **dd,
         )
         if weight_init != 'skip':
             self.init_weights(weight_init)
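
With `device`/`dtype` threaded through every submodule, a model from this file can, in principle, be materialized directly in the requested dtype or on the requested device. A hedged usage sketch, assuming the default constructor arguments of `SwinTransformer` and a timm build that includes this commit:

import torch
from timm.models.swin_transformer import SwinTransformer

# Parameters and buffers are created in float64 from the start rather than
# being converted after construction; a CUDA or meta device (or bfloat16)
# could be passed the same way, though those paths depend on every op used
# in __init__ supporting that device/dtype.
model = SwinTransformer(dtype=torch.float64)
print(next(model.parameters()).dtype)  # torch.float64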
