Skip to content

Commit fc3bbc8

Browse files
committed
Add grad-cam example
1 parent b8b4298 commit fc3bbc8

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

dezero/functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,15 @@ def broadcast_to(x, shape):
245245
return BroadcastTo(shape)(x)
246246

247247

248+
def average(x, axis=None, keepdims=False):
249+
x = as_variable(x)
250+
y = sum(x, axis, keepdims)
251+
return y * (y.data.size / x.data.size)
252+
253+
254+
mean = average
255+
256+
248257
class MatMul(Function):
249258
def forward(self, x, W):
250259
y = x.dot(W)

examples/grad_cam.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
Simple implementation of Grad-CAM (https://arxiv.org/pdf/1610.02391.pdf)
3+
"""
4+
import numpy as np
5+
from PIL import Image
6+
import cv2
7+
import dezero
8+
import dezero.functions as F
9+
from dezero.models import VGG16
10+
11+
12+
url = 'https://github.com/oreilly-japan/deep-learning-from-scratch-3/raw/images/zebra.jpg'
13+
img_path = dezero.utils.get_file(url)
14+
img = Image.open(img_path)
15+
img_size = img.size
16+
17+
model = VGG16(pretrained=True)
18+
x = VGG16.preprocess(img)[np.newaxis] # preprocess for VGG
19+
y = model(x)
20+
last_conv_output = model.conv5_3.outputs[0]()
21+
predict_id = np.argmax(y.data)
22+
predict_output = y[0, predict_id]
23+
24+
predict_output.backward(retain_grad=True)
25+
grads = last_conv_output.grad
26+
pooled_grads = F.average(grads, axis=(0, 2, 3))
27+
28+
heatmap = last_conv_output.data[0]
29+
for c in range(heatmap.shape[0]):
30+
heatmap[c] *= pooled_grads[c].data
31+
32+
heatmap = np.mean(heatmap, axis=0)
33+
heatmap = np.maximum(heatmap, 0)
34+
heatmap /= np.max(heatmap)
35+
36+
# visualize the heatmap on image
37+
img = cv2.imread(img_path)
38+
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
39+
heatmap = np.uint8(255 * heatmap)
40+
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
41+
heatmap_on_img = heatmap * 0.4 + img
42+
cv2.imwrite('grad_cam.png', heatmap_on_img)

0 commit comments

Comments
 (0)