-
Notifications
You must be signed in to change notification settings - Fork 582
Open
Description
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x) # [2, 3, 1024, 1024] -> [2, 64, 64, 768]
if self.pos_embed is not None:
x = x + self.pos_embed
global_features = []
for i, blk in enumerate(self.blocks):
x = blk(x) # [2, 64, 64, 768] -> [2, 64, 64, 768]
if self.sam_hd and blk.window_size == 0: # "global_attn_indexes": [2, 5, 8, 11]
global_features.append(x)
x = self.neck(x.permute(0, 3, 1, 2)) # [2, 64, 64, 768]->[2, 256, 64, 64]
x_dtype = x.dtype
x = F.interpolate(
x.float(), size=(96, 96), mode="bilinear", align_corners=False
).to(x_dtype) # [2, 64, 64, 768]->[2, 256, 96, 96]
x = self.downsamples(x) # [2, 256, 96, 96]->[2, 1024, 24, 24]
if self.sam_hd:
first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2)) # [2, 64, 64, 768]->[2, 256, 64, 64]
x_dtype = first_global_feature.dtype
first_global_feature = F.interpolate(
first_global_feature.float(),
size=(96, 96),
mode="bilinear",
align_corners=False,
) # [2, 256, 64, 64] -> [2, 256, 96, 96]
first_global_feature = self.downsamples(first_global_feature.to(x_dtype)) # [2, 256, 96, 96] -> [2, 1024, 24, 24]
x = x + first_global_feature * self.hd_alpha_downsamples #这里self.hd_alpha_downsamples = 0,也就是first_global_feature的计算没有意义
上面的代码是 sam.py 文件中的一段。从这段代码看,self.hd_alpha_downsamples 似乎一直为 0,那么计算 first_global_feature 的作用是什么?其实它并不是一直为 0,外部会对这个变量进行修改。
Metadata
Metadata
Assignees
Labels
No labels