from typing import Optional, Type
from warnings import warn

import torch
from torch import nn

from captum.optim.models._common import RedirectedReluLayer, SkipLayer

GS_SAVED_WEIGHTS_URL = (
    "https://pytorch.s3.amazonaws.com/models/captum/clip_resnet50x4_image.pt"
)


def clip_resnet50x4_image(
    pretrained: bool = False,
    progress: bool = True,
    model_path: Optional[str] = None,
    **kwargs
) -> "CLIP_ResNet50x4Image":
    """
    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020

    This model can be combined with the CLIP ResNet 50x4 Text model to create the full
    CLIP ResNet 50x4 model.

    AvgPool2d layers were replaced with AdaptiveAvgPool2d to allow for any input
    height and width, though the best results are obtained with the model's intended
    input height and width of 288x288.

    See here for more details:
    https://github.com/openai/CLIP
    https://github.com/mlfoundations/open_clip

    Args:

        pretrained (bool, optional): If True, returns a pre-trained model.
            Default: False
        progress (bool, optional): If True, displays a progress bar of the download
            to stderr.
            Default: True
        model_path (str, optional): Optional path for the model file.
            Default: None
        replace_relus_with_redirectedrelu (bool, optional): If True, return the model
            with Redirected ReLU in place of ReLU layers.
            Default: *True* when pretrained is True, otherwise *False*
        use_linear_modules_only (bool, optional): If True, return the model with all
            nonlinear layers replaced with linear equivalents.
            Default: False
        transform_input (bool, optional): If True, preprocess the input according to
            the method with which the model was trained.
            Default: *True* when pretrained is True, otherwise *False*

    Returns:
        **CLIP_ResNet50x4Image** (CLIP_ResNet50x4Image): The image portion of a CLIP
            ResNet 50x4 model.
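
    A minimal usage sketch (assuming network access to download the pretrained
    weights; the 288x288 input size and the 640-dimensional output follow from the
    defaults documented above):

    Example::

        >>> model = clip_resnet50x4_image(pretrained=True)
        >>> x = torch.rand(1, 3, 288, 288)  # values expected to be in [0, 1]
        >>> embedding = model(x)  # attention-pooled image features
        >>> embedding.shape
        torch.Size([1, 640])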
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "replace_relus_with_redirectedrelu" not in kwargs:
            kwargs["replace_relus_with_redirectedrelu"] = True
        if "use_linear_modules_only" not in kwargs:
            kwargs["use_linear_modules_only"] = False

        model = CLIP_ResNet50x4Image(**kwargs)

        if model_path is None:
            state_dict = torch.hub.load_state_dict_from_url(
                GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False
            )
        else:
            state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    return CLIP_ResNet50x4Image(**kwargs)


class CLIP_ResNet50x4Image(nn.Module):
    """
    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
    """
    __constants__ = ["transform_input"]

    def __init__(
        self,
        transform_input: bool = False,
        replace_relus_with_redirectedrelu: bool = False,
        use_linear_modules_only: bool = False,
    ) -> None:
        """
        Args:

            replace_relus_with_redirectedrelu (bool, optional): If True, return the
                model with Redirected ReLU in place of ReLU layers.
                Default: False
            use_linear_modules_only (bool, optional): If True, return the model with
                all nonlinear layers replaced with linear equivalents.
                Default: False
            transform_input (bool, optional): If True, preprocess the input according
                to the method with which the model was trained.
                Default: False
        """
        super().__init__()
        if use_linear_modules_only:
            activ = SkipLayer
        elif replace_relus_with_redirectedrelu:
            activ = RedirectedReluLayer
        else:
            activ = nn.ReLU

        self.transform_input = transform_input

        # Stem layers
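        # The CLIP variant of ResNet uses three stacked 3x3 convolutions as the stem
        # in place of the standard single 7x7 convolution.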
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = activ()
        self.conv2 = nn.Conv2d(40, 40, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(40)
        self.relu2 = activ()
        self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(80)
        self.relu3 = activ()
        self.avgpool = nn.AdaptiveAvgPool2d(72)

        # Residual layers
        self.layer1 = self._build_layer(80, 80, 4, stride=1, pooling=72, activ=activ)
        self.layer2 = self._build_layer(320, 160, 6, stride=2, pooling=36, activ=activ)
        self.layer3 = self._build_layer(640, 320, 10, stride=2, pooling=18, activ=activ)
        self.layer4 = self._build_layer(1280, 640, 6, stride=2, pooling=9, activ=activ)

        # Attention Pooling
        self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)

    def _build_layer(
        self,
        inplanes: int = 80,
        planes: int = 80,
        blocks: int = 4,
        stride: int = 1,
        pooling: int = 72,
        activ: Type[nn.Module] = nn.ReLU,
    ) -> nn.Module:
        """
        Residual layer creation helper function.

        Args:

            inplanes (int, optional): The number of input channels / features to use
                for the first layer.
                Default: 80
            planes (int, optional): The number of output channels / features to use
                for the first layer. This variable is then multiplied by 4 to get the
                number of input channels / features to use for the subsequent layers.
                Default: 80
            blocks (int, optional): The number of Bottleneck layers to create.
                Default: 4
            stride (int, optional): The stride value to use for the first Bottleneck
                layer.
                Default: 1
            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
                Default: 72
            activ (type of nn.Module, optional): The nn.Module class type to use for
                activation layers.
                Default: nn.ReLU

        Returns:
            residual_layer (nn.Sequential): A full residual layer.
        """
        layers = [Bottleneck(inplanes, planes, stride, pooling=pooling, activ=activ)]
        for _ in range(blocks - 1):
            layers += [Bottleneck(planes * 4, planes, pooling=pooling, activ=activ)]
        return nn.Sequential(*layers)

    def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to normalize the values of.

        Returns:
            x (torch.Tensor): A normalized tensor.
        """
        assert x.dim() == 3 or x.dim() == 4
        if self.transform_input:
            if x.min() < 0.0 or x.max() > 1.0:
                warn("Model input has values outside of the range [0, 1].")
            x = x.unsqueeze(0) if x.dim() == 3 else x
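            # Normalize using the mean and standard deviation that CLIP was trained
            # with.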
            x = x - torch.tensor(
                [0.48145466, 0.4578275, 0.40821073], device=x.device
            ).view(3, 1, 1)
            x = x / torch.tensor(
                [0.26862954, 0.26130258, 0.27577711], device=x.device
            ).view(3, 1, 1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the model.

        Returns:
            x (torch.Tensor): The model output.
        """
        x = self._transform_input(x)

        # Stem layers
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)

        # Residual layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Attention Pooling
        x = self.attnpool(x)
        return x


class Bottleneck(nn.Module):
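    """
    A ResNet bottleneck block as modified for CLIP: downsampling is performed by
    average pooling rather than by strided convolutions, with AdaptiveAvgPool2d used
    here so that any input size is supported.
    """
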
    def __init__(
        self,
        inplanes: int = 80,
        planes: int = 80,
        stride: int = 1,
        pooling: int = 72,
        activ: Type[nn.Module] = nn.ReLU,
    ) -> None:
        """
        Args:

            inplanes (int, optional): The number of input channels / features to use
                for the first layer.
                Default: 80
            planes (int, optional): The number of intermediate channels / features to
                use for the block. The block's output has planes * 4 channels.
                Default: 80
            stride (int, optional): The stride value for the block. Values greater
                than 1 cause a projection shortcut to be used.
                Default: 1
            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
                Default: 72
            activ (type of nn.Module, optional): The nn.Module class type to use for
                activation layers.
                Default: nn.ReLU
        """
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = activ()

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = activ()

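        # Anti-aliased downsampling: an average pool is applied ahead of the final
        # 1x1 convolution instead of using a strided convolution.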
        self.avgpool = nn.AdaptiveAvgPool2d(pooling)

        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu3 = activ()

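        # Projection shortcut: average pool followed by a 1x1 convolution whenever
        # the spatial size or channel count changes.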
        if stride > 1 or inplanes != planes * 4:
            self.downsample = nn.Sequential(
                nn.AdaptiveAvgPool2d(pooling),
                nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * 4),
            )
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the module.

        Returns:
            x (torch.Tensor): The module output.
        """
        assert x.dim() == 4
        if self.downsample is not None:
            identity = self.downsample(x)
        else:
            identity = x.clone()

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.avgpool(x)

        x = self.bn3(self.conv3(x)) + identity
        x = self.relu3(x)
        return x


class AttentionPool2d(nn.Module):
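    """
    The attention pooling head from CLIP's modified ResNet, which replaces the usual
    global average pool: the spatial features are flattened into a sequence, a mean
    token is prepended, and the multi-head attention output at that token is returned
    as the pooled feature vector.
    """
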
    def __init__(
        self,
        spacial_size: int = 9,
        in_features: int = 2560,
        out_features: int = 640,
        num_heads: int = 40,
    ) -> None:
        """
        Args:

            spacial_size (int, optional): The spatial size of the input, used to size
                the positional embedding.
                Default: 9
            in_features (int, optional): The desired input size for the nn.Linear
                layers.
                Default: 2560
            out_features (int, optional): The desired output size for the nn.Linear
                layers.
                Default: 640
            num_heads (int, optional): The number of heads to use.
                Default: 40
        """
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_size**2 + 1, in_features) / in_features**0.5
        )
        self.k_proj = nn.Linear(in_features, in_features)
        self.q_proj = nn.Linear(in_features, in_features)
        self.v_proj = nn.Linear(in_features, in_features)
        self.c_proj = nn.Linear(in_features, out_features)
        self.num_heads = num_heads

    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the module.

        Returns:
            x (torch.Tensor): The module output.
        """
        assert x.dim() == 4
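        # Flatten the spatial dimensions into a sequence: (B, C, H, W) -> (HW, B, C).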
        x = x.reshape(*x.shape[:2], -1).permute(2, 0, 1)
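        # Prepend the mean over all positions as a global token; its attention output
        # is what gets returned as the pooled feature.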
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :]
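        # Multi-head attention with separate Q / K / V projection weights; only the
        # output at the first (global token) position is returned.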
        return torch.nn.functional.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )[0][0]