Skip to content

Commit 12bb590

Browse files
authored
Update model.py
1 parent 8dfcef6 commit 12bb590

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

model.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,43 @@ def weights_init_classifier(m):
2525
init.normal_(m.weight.data, std=0.001)
2626
init.constant_(m.bias.data, 0.0)
2727

28+
class USAM(nn.Module):
29+
#Joint Representation Learning and Keypoint Detection for Cross-view Geo-localization. TIP2022
30+
def __init__(self, kernel_size=3, padding=1, polish=True):
31+
super(USAM, self).__init__()
32+
33+
kernel = torch.ones((kernel_size, kernel_size))
34+
kernel = kernel.unsqueeze(0).unsqueeze(0)
35+
self.weight = nn.Parameter(data=kernel, requires_grad=False)
36+
37+
38+
kernel2 = torch.ones((1, 1)) * (kernel_size * kernel_size)
39+
kernel2 = kernel2.unsqueeze(0).unsqueeze(0)
40+
self.weight2 = nn.Parameter(data=kernel2, requires_grad=False)
41+
42+
self.polish = polish
43+
self.pad = padding
44+
self.relu = nn.ReLU()
45+
self.bn = nn.BatchNorm2d(1)
46+
47+
def __call__(self, x):
48+
fmap = x.sum(1, keepdim=True)
49+
x1 = F.conv2d(fmap, self.weight, padding=self.pad)
50+
x2 = F.conv2d(fmap, self.weight2, padding=0)
51+
52+
att = x2 - x1
53+
att = self.bn(att)
54+
att = self.relu(att)
55+
56+
if self.polish:
57+
att[:, :, :, 0] = 0
58+
att[:, :, :, -1] = 0
59+
att[:, :, 0, :] = 0
60+
att[:, :, -1, :] = 0
61+
62+
output = x + att * x
63+
64+
return output
2865

2966
def activate_drop(m):
3067
classname = m.__class__.__name__
@@ -73,7 +110,7 @@ def forward(self, x):
73110
# Define the ResNet50-based Model
74111
class ft_net(nn.Module):
75112

76-
def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512):
113+
def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512, usam=False):
77114
super(ft_net, self).__init__()
78115
model_ft = models.resnet50(pretrained=True)
79116
if ibn==True:
@@ -86,13 +123,20 @@ def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=Fals
86123
self.model = model_ft
87124
self.circle = circle
88125
self.classifier = ClassBlock(2048, class_num, droprate, linear=linear_num, return_f = circle)
126+
if usam:
127+
self.usam_1 = USAM()
128+
self.usam_2 = USAM()
89129

90130
def forward(self, x):
91131
x = self.model.conv1(x)
92132
x = self.model.bn1(x)
93133
x = self.model.relu(x)
134+
if self.usam:
135+
x = self.usam_1(x)
94136
x = self.model.maxpool(x)
95137
x = self.model.layer1(x)
138+
if self.usam:
139+
x = self.usam_2(x)
96140
x = self.model.layer2(x)
97141
x = self.model.layer3(x)
98142
x = self.model.layer4(x)

0 commit comments

Comments
 (0)