@@ -25,6 +25,43 @@ def weights_init_classifier(m):
2525 init .normal_ (m .weight .data , std = 0.001 )
2626 init .constant_ (m .bias .data , 0.0 )
2727
28+ class USAM (nn .Module ):
29+ #Joint Representation Learning and Keypoint Detection for Cross-view Geo-localization. TIP2022
30+ def __init__ (self , kernel_size = 3 , padding = 1 , polish = True ):
31+ super (USAM , self ).__init__ ()
32+
33+ kernel = torch .ones ((kernel_size , kernel_size ))
34+ kernel = kernel .unsqueeze (0 ).unsqueeze (0 )
35+ self .weight = nn .Parameter (data = kernel , requires_grad = False )
36+
37+
38+ kernel2 = torch .ones ((1 , 1 )) * (kernel_size * kernel_size )
39+ kernel2 = kernel2 .unsqueeze (0 ).unsqueeze (0 )
40+ self .weight2 = nn .Parameter (data = kernel2 , requires_grad = False )
41+
42+ self .polish = polish
43+ self .pad = padding
44+ self .relu = nn .ReLU ()
45+ self .bn = nn .BatchNorm2d (1 )
46+
47+ def __call__ (self , x ):
48+ fmap = x .sum (1 , keepdim = True )
49+ x1 = F .conv2d (fmap , self .weight , padding = self .pad )
50+ x2 = F .conv2d (fmap , self .weight2 , padding = 0 )
51+
52+ att = x2 - x1
53+ att = self .bn (att )
54+ att = self .relu (att )
55+
56+ if self .polish :
57+ att [:, :, :, 0 ] = 0
58+ att [:, :, :, - 1 ] = 0
59+ att [:, :, 0 , :] = 0
60+ att [:, :, - 1 , :] = 0
61+
62+ output = x + att * x
63+
64+ return output
2865
2966def activate_drop (m ):
3067 classname = m .__class__ .__name__
@@ -73,7 +110,7 @@ def forward(self, x):
73110# Define the ResNet50-based Model
74111class ft_net (nn .Module ):
75112
76- def __init__ (self , class_num = 751 , droprate = 0.5 , stride = 2 , circle = False , ibn = False , linear_num = 512 ):
113+ def __init__ (self , class_num = 751 , droprate = 0.5 , stride = 2 , circle = False , ibn = False , linear_num = 512 , usam = False ):
77114 super (ft_net , self ).__init__ ()
78115 model_ft = models .resnet50 (pretrained = True )
79116 if ibn == True :
@@ -86,13 +123,20 @@ def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=Fals
86123 self .model = model_ft
87124 self .circle = circle
88125 self .classifier = ClassBlock (2048 , class_num , droprate , linear = linear_num , return_f = circle )
126+ if usam :
127+ self .usam_1 = USAM ()
128+ self .usam_2 = USAM ()
89129
90130 def forward (self , x ):
91131 x = self .model .conv1 (x )
92132 x = self .model .bn1 (x )
93133 x = self .model .relu (x )
134+ if self .usam :
135+ x = self .usam_1 (x )
94136 x = self .model .maxpool (x )
95137 x = self .model .layer1 (x )
138+ if self .usam :
139+ x = self .usam_2 (x )
96140 x = self .model .layer2 (x )
97141 x = self .model .layer3 (x )
98142 x = self .model .layer4 (x )
0 commit comments