
Commit 34bedcf

Add the CLIP ResNet 50x4 model
1 parent aae34d0 commit 34bedcf

5 files changed: +775 -0 lines changed

captum/optim/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -6,9 +6,14 @@
     replace_layers,
     skip_layers,
 )
+from ._image.clip_resnet50x4_image import CLIP_ResNet50x4Image  # noqa: F401
+from ._image.clip_resnet50x4_image import clip_resnet50x4_image  # noqa: F401
+from ._image.clip_resnet50x4_text import CLIP_ResNet50x4Text  # noqa: F401
+from ._image.clip_resnet50x4_text import clip_resnet50x4_text  # noqa: F401
 from ._image.inception5h_classes import INCEPTION5H_CLASSES  # noqa: F401
 from ._image.inception_v1 import InceptionV1, googlenet  # noqa: F401
 
+
 __all__ = [
     "RedirectedReluLayer",
     "SkipLayer",
@@ -19,4 +24,8 @@
     "InceptionV1",
     "googlenet",
     "INCEPTION5H_CLASSES",
+    "CLIP_ResNet50x4Image",
+    "clip_resnet50x4_image",
+    "CLIP_ResNet50x4Text",
+    "clip_resnet50x4_text",
 ]
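Not part of the diff, but for orientation: once this change lands, the four new names become importable straight from captum.optim.models. A minimal sketch:

# Illustrative only: the new top-level exports added above.
from captum.optim.models import (
    CLIP_ResNet50x4Image,
    clip_resnet50x4_image,
    CLIP_ResNet50x4Text,
    clip_resnet50x4_text,
)

# The lowercase names are builder functions; the CamelCase names are the
# underlying nn.Module classes defined in the new model files.
model = clip_resnet50x4_image()  # randomly initialized image portion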
captum/optim/models/_image/clip_resnet50x4_image.py

Lines changed: 369 additions & 0 deletions
@@ -0,0 +1,369 @@
from typing import Optional, Type
from warnings import warn

import torch
from torch import nn

from captum.optim.models._common import RedirectedReluLayer, SkipLayer

GS_SAVED_WEIGHTS_URL = (
    "https://pytorch.s3.amazonaws.com/models/captum/clip_resnet50x4_image.pt"
)


def clip_resnet50x4_image(
    pretrained: bool = False,
    progress: bool = True,
    model_path: Optional[str] = None,
    **kwargs
) -> "CLIP_ResNet50x4Image":
    """
    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020

    This model can be combined with the CLIP ResNet 50x4 Text model to create the full
    CLIP ResNet 50x4 model.

    AvgPool2d layers were replaced with AdaptiveAvgPool2d to allow for any input height
    and width size, though the best results are obtained by using the model's intended
    input height and width of 288x288.

    See here for more details:
    https://github.com/openai/CLIP
    https://github.com/mlfoundations/open_clip

    Args:

        pretrained (bool, optional): If True, returns a pre-trained model.
            Default: False
        progress (bool, optional): If True, displays a progress bar of the download to
            stderr
            Default: True
        model_path (str, optional): Optional path for the model file.
            Default: None
        replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
            model with Redirected ReLU in place of ReLU layers.
            Default: *True* when pretrained is True otherwise *False*
        use_linear_modules_only (bool, optional): If True, return model
            with all nonlinear layers replaced with linear equivalents.
            Default: False
        transform_input (bool, optional): If True, preprocesses the input according to
            the method with which it was trained.
            Default: *True* when pretrained is True otherwise *False*

    Returns:
        **CLIP_ResNet50x4Image** (CLIP_ResNet50x4Image): A CLIP ResNet 50x4 model's
            image portion.
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "replace_relus_with_redirectedrelu" not in kwargs:
            kwargs["replace_relus_with_redirectedrelu"] = True
        if "use_linear_modules_only" not in kwargs:
            kwargs["use_linear_modules_only"] = False

        model = CLIP_ResNet50x4Image(**kwargs)

        if model_path is None:
            state_dict = torch.hub.load_state_dict_from_url(
                GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False
            )
        else:
            state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    return CLIP_ResNet50x4Image(**kwargs)
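Illustrative usage of the builder above (not part of the commit). With pretrained=True, the transform_input and replace_relus_with_redirectedrelu kwargs default to True, matching the branch shown above; the local checkpoint path below is a hypothetical placeholder:

# Sketch only: download the released weights and switch to eval mode.
from captum.optim.models import clip_resnet50x4_image

model = clip_resnet50x4_image(pretrained=True)
model.eval()

# Or load the same weights from a local file instead of the S3 URL
# ("./clip_resnet50x4_image.pt" is a placeholder path):
local_model = clip_resnet50x4_image(
    pretrained=True, model_path="./clip_resnet50x4_image.pt"
)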
class CLIP_ResNet50x4Image(nn.Module):
    """
    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
    """
    __constants__ = ["transform_input"]

    def __init__(
        self,
        transform_input: bool = False,
        replace_relus_with_redirectedrelu: bool = False,
        use_linear_modules_only: bool = False,
    ) -> None:
        """
        Args:

            replace_relus_with_redirectedrelu (bool, optional): If True, return
                model with Redirected ReLU in place of ReLU layers.
                Default: False
            use_linear_modules_only (bool, optional): If True, return model with
                all nonlinear layers replaced with linear equivalents.
                Default: False
            transform_input (bool, optional): If True, preprocesses the input according
                to the method with which it was trained.
                Default: False
        """
        super().__init__()
        if use_linear_modules_only:
            activ = SkipLayer
        else:
            if replace_relus_with_redirectedrelu:
                activ = RedirectedReluLayer
            else:
                activ = nn.ReLU

        self.transform_input = transform_input

        # Stem layers
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = activ()
        self.conv2 = nn.Conv2d(40, 40, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(40)
        self.relu2 = activ()
        self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(80)
        self.relu3 = activ()
        self.avgpool = nn.AdaptiveAvgPool2d(72)

        # Residual layers
        self.layer1 = self._build_layer(80, 80, 4, stride=1, pooling=72, activ=activ)
        self.layer2 = self._build_layer(320, 160, 6, stride=2, pooling=36, activ=activ)
        self.layer3 = self._build_layer(640, 320, 10, stride=2, pooling=18, activ=activ)
        self.layer4 = self._build_layer(1280, 640, 6, stride=2, pooling=9, activ=activ)

        # Attention Pooling
        self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)

    def _build_layer(
        self,
        inplanes: int = 80,
        planes: int = 80,
        blocks: int = 4,
        stride: int = 1,
        pooling: int = 72,
        activ: Type[nn.Module] = nn.ReLU,
    ) -> nn.Module:
        """
        Residual layer creation helper function.

        Args:

            inplanes (int, optional): The number of input channels / features to use
                for the first layer.
                Default: 80
            planes (int, optional): The number of output channels / features to use
                for the first layer. This variable is then multiplied by 4 to get the
                number of input channels / features to use for the subsequent layers.
                Default: 80
            blocks (int, optional): The number of Bottleneck layers to create.
                Default: 4
            stride (int, optional): The stride value to use for the Bottleneck layers.
                Default: 1
            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
                Default: 72
            activ (type of nn.Module, optional): The nn.Module class type to use for
                activation layers.
                Default: nn.ReLU

        Returns:
            residual_layer (nn.Sequential): A full residual layer.
        """
        layers = [Bottleneck(inplanes, planes, stride, pooling=pooling, activ=activ)]
        for _ in range(blocks - 1):
            layers += [Bottleneck(planes * 4, planes, pooling=pooling, activ=activ)]
        return nn.Sequential(*layers)

    def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to normalize the values of.

        Returns:
            x (torch.Tensor): A normalized tensor.
        """
        assert x.dim() == 3 or x.dim() == 4
        if self.transform_input:
            if x.min() < 0.0 or x.max() > 1.0:
                warn("Model input has values outside of the range [0, 1].")
            x = x.unsqueeze(0) if x.dim() == 3 else x
            x = x - torch.tensor(
                [0.48145466, 0.4578275, 0.40821073], device=x.device
            ).view(3, 1, 1)
            x = x / torch.tensor(
                [0.26862954, 0.26130258, 0.27577711], device=x.device
            ).view(3, 1, 1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the model.

        Returns:
            x (torch.Tensor): The model output.
        """
        x = self._transform_input(x)

        # Stem layers
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)

        # Residual layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Attention Pooling
        x = self.attnpool(x)
        return x
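A minimal forward-pass sketch for the class above (again, not part of the diff). The 288x288 input size comes from the docstring, values in [0, 1] are what _transform_input expects, and the 640-dimensional output follows from attnpool's out_features:

# Sketch only: a random image batch through the image portion of CLIP ResNet 50x4.
import torch
from captum.optim.models import CLIP_ResNet50x4Image

model = CLIP_ResNet50x4Image(transform_input=True).eval()
image = torch.rand(1, 3, 288, 288)  # values already in [0, 1]
with torch.no_grad():
    embedding = model(image)
print(embedding.shape)  # expected: torch.Size([1, 640])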
class Bottleneck(nn.Module):
    def __init__(
        self,
        inplanes: int = 80,
        planes: int = 80,
        stride: int = 1,
        pooling: int = 72,
        activ: Type[nn.Module] = nn.ReLU,
    ) -> None:
        """
        Args:

            inplanes (int, optional): The number of input channels / features to use
                for the first layer.
                Default: 80
            planes (int, optional): The number of output channels / features to use
                for the subsequent layers.
                Default: 80
            stride (int, optional): The stride value to use for the Bottleneck layers.
                Default: 1
            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
                Default: 72
            activ (type of nn.Module, optional): The nn.Module class type to use for
                activation layers.
                Default: nn.ReLU
        """
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = activ()

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = activ()

        self.avgpool = nn.AdaptiveAvgPool2d(pooling)

        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu3 = activ()

        if stride > 1 or inplanes != planes * 4:
            self.downsample = nn.Sequential(
                nn.AdaptiveAvgPool2d(pooling),
                nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * 4),
            )
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the module.

        Returns:
            x (torch.Tensor): The module output.
        """
        assert x.dim() == 4
        if self.downsample is not None:
            identity = self.downsample(x)
        else:
            identity = x.clone()

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.avgpool(x)

        x = self.bn3(self.conv3(x)) + identity
        x = self.relu3(x)
        return x


class AttentionPool2d(nn.Module):
    def __init__(
        self,
        spacial_size: int = 9,
        in_features: int = 2560,
        out_features: int = 640,
        num_heads: int = 40,
    ) -> None:
        """
        Args:

            spacial_size (int, optional): The desired size to use for the positional
                embedding.
                Default: 9
            in_features (int, optional): The desired input size for the nn.Linear
                layers.
                Default: 2560
            out_features (int, optional): The desired output size for the nn.Linear
                layers.
                Default: 640
            num_heads (int, optional): The number of heads to use.
                Default: 40
        """
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_size**2 + 1, in_features) / in_features**0.5
        )
        self.k_proj = nn.Linear(in_features, in_features)
        self.q_proj = nn.Linear(in_features, in_features)
        self.v_proj = nn.Linear(in_features, in_features)
        self.c_proj = nn.Linear(in_features, out_features)
        self.num_heads = num_heads

    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:

            x (torch.Tensor): An input tensor to run through the module.

        Returns:
            x (torch.Tensor): The module output.
        """
        assert x.dim() == 4
        x = x.reshape(*x.shape[:2], -1).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :]
        return torch.nn.functional.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )[0][0]
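To make the attention pooling head's shape contract concrete, a small check (not part of the commit; it assumes the module path used by __init__.py above):

# Sketch only: AttentionPool2d maps a (N, 2560, 9, 9) feature map to (N, 640).
import torch
from captum.optim.models._image.clip_resnet50x4_image import AttentionPool2d

pool = AttentionPool2d(spacial_size=9, in_features=2560, out_features=640, num_heads=40)
features = torch.randn(2, 2560, 9, 9)  # e.g. the layer4 output for a batch of 2
with torch.no_grad():
    pooled = pool(features)
print(pooled.shape)  # expected: torch.Size([2, 640])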
