Commit 1e172a0

dd kwargs for naflexvit, needs revisit for nn.Parameters
1 parent 8cbbf39 commit 1e172a0
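
The change threads a dd = {'device': device, 'dtype': dtype} dict through the NaFlexVit and NaFlexEmbeds constructors so that parameters, projections, and norm layers can be created directly with a caller-supplied device and dtype, following the factory-kwargs convention of torch.nn layers. A minimal standalone sketch of that pattern, written on a made-up TinyEmbed module rather than code from this commit:

# Illustration only (not code from this commit): the device/dtype "factory kwargs"
# pattern applied to a made-up module.
import torch
import torch.nn as nn

class TinyEmbed(nn.Module):
    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Forward device/dtype to every tensor factory and submodule so the
        # module is built directly where the caller wants it.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim, **dd))
        self.proj = nn.Linear(dim, dim, **dd)

embed = TinyEmbed(64, device='cpu', dtype=torch.float32)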

1 file changed: +28 −12 lines

timm/models/naflexvit.py

Lines changed: 28 additions & 12 deletions
@@ -203,8 +203,8 @@ def __init__(
             unique_sizes: List[Tuple[int, int]],
             batch_size: int,
             seq_len: int,
-            dtype: torch.dtype,
             device: torch.device,
+            dtype: torch.dtype,
     ):
         self.rope = rope_module
         self.size_to_indices = size_to_indices
@@ -362,6 +362,8 @@ def __init__(
             norm_layer: Optional[Type[nn.Module]] = None,
             pos_drop_rate: float = 0.,
             enable_patch_interpolator: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize NaFlexEmbeds module.

@@ -385,6 +387,7 @@ def __init__(
             pos_drop_rate: Dropout rate for position embeddings.
             enable_patch_interpolator: Enable dynamic patch size support.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.has_class_token = class_token
         self.num_reg_tokens = reg_tokens
@@ -402,8 +405,8 @@ def __init__(
         self.num_prefix_tokens += reg_tokens

         # Create class and register tokens
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None

         # Calculate grid size and number of patches
         self.default_img_size: Optional[Tuple[int, int]] = None
@@ -425,15 +428,20 @@ def __init__(
                 "`norm_layer` must be given when input_norm_layer=True"
             input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None)
             self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None
-            self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias)
+            self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias, **dd)
             self.flatten = False
             self.is_linear = True
         else:
             # Default to convolutional patch embedding for image inputs
             assert not input_norm_layer
             self.norm_input = None
             self.proj = nn.Conv2d(
-                in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=proj_bias
+                in_chans,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                bias=proj_bias,
+                **dd,
             )
             self.flatten = True
             self.is_linear = False
@@ -470,12 +478,12 @@ def __init__(
             assert self.pos_embed_grid_size is not None
             h, w = self.pos_embed_grid_size
             self.pos_embed_type = 'factorized'
-            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
-            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
+            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim, **dd) * .02)
+            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim, **dd) * .02)
         else:
             assert self.pos_embed_grid_size is not None
             h, w = self.pos_embed_grid_size
-            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
+            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim, **dd) * .02)
             self.pos_embed_type = 'learned'

         # Dropout layer
@@ -1080,6 +1088,8 @@ def __init__(
             in_chans: int = 3,
             num_classes: int = 1000,
             img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            device=None,
+            dtype=None,
             **kwargs,
     ) -> None:
         """Initialize NaFlexVit model.
@@ -1092,6 +1102,7 @@ def __init__(
             **kwargs: Additional config parameters to override cfg values.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}

         # Initialize config
         cfg = cfg or NaFlexVitCfg()
@@ -1141,8 +1152,9 @@ def __init__(
             proj_norm_layer=embed_norm_layer,
             pos_drop_rate=cfg.pos_drop_rate,
             enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False),
+            **dd,
         )
-        self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity()

         # ROPE position embeddings at model level
         self.rope: Optional[nn.Module] = None
@@ -1157,6 +1169,7 @@ def __init__(
                 temperature=cfg.rope_temperature,
                 feat_shape=None,  # Dynamic shapes for NaFlex
                 grid_indexing=cfg.rope_grid_indexing,
+                **dd,
             )
             self.rope_is_mixed = True
         elif cfg.rope_type == 'axial':
@@ -1168,6 +1181,7 @@ def __init__(
                 ref_feat_shape=cfg.rope_ref_feat_shape,
                 grid_offset=cfg.rope_grid_offset,
                 grid_indexing=cfg.rope_grid_indexing,
+                **dd,
             )
             self.rope_is_mixed = False
         else:
@@ -1200,6 +1214,7 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                **dd,
             )
             for i in range(cfg.depth)
         ])
@@ -1211,7 +1226,7 @@ def __init__(
             for i in range(cfg.depth)
         ]

-        self.norm = norm_layer(cfg.embed_dim) if cfg.final_norm and not cfg.fc_norm else nn.Identity()
+        self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity()

         # Classifier Head
         if cfg.global_pool == 'map':
@@ -1221,6 +1236,7 @@ def __init__(
                 mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
+                **dd,
             )
         else:
             self.attn_pool = None
@@ -1229,9 +1245,9 @@ def __init__(
         fc_norm = cfg.fc_norm
         if fc_norm is None:
             fc_norm = cfg.global_pool == 'avg'
-        self.fc_norm = norm_layer(cfg.embed_dim) if cfg.final_norm and fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(cfg.drop_rate)
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

         if cfg.weight_init != 'skip':
             self.init_weights(cfg.weight_init)
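
One thing these kwargs enable is deferred initialization: build the module structure on the meta device, then materialize it on the real target. Parameters created on meta carry no data, so they must be allocated and re-initialized afterwards, which is presumably the "needs revisit for nn.Parameters" caveat in the commit title. A hypothetical usage sketch, assuming NaFlexVit takes a cfg plus the new device kwarg as in the diff above and that init_weights('') runs the default initialization as in other timm ViTs:

# Hypothetical usage; constructor and init_weights details are assumptions,
# not confirmed by this commit.
import torch
from timm.models.naflexvit import NaFlexVit, NaFlexVitCfg

cfg = NaFlexVitCfg(weight_init='skip')                    # skip init so meta tensors aren't touched
model = NaFlexVit(cfg, num_classes=1000, device='meta')   # structure only, no real storage allocated

model = model.to_empty(device='cuda')  # allocate uninitialized storage on the GPU
model.init_weights('')                 # run the real weight initialization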
