Skip to content

Commit cdded76

Browse files
committed
Add style transfer example
1 parent 011aa1e commit cdded76

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ DeZeroの他の実装例は[examples](/examples)にあります。
6767

6868
[<img src="https://raw.githubusercontent.com/oreilly-japan/deep-learning-from-scratch-3/images/gan.gif" height="175"/>](/examples/gan.py)[<img src="https://raw.githubusercontent.com/oreilly-japan/deep-learning-from-scratch-3/images/vae.png" height="175"/>](/examples/vae.py)[<img src="https://raw.githubusercontent.com/oreilly-japan/deep-learning-from-scratch-3/images/pythonista.png" height="175"/>](https://github.com/oreilly-japan/deep-learning-from-scratch-3/wiki/DeZero%E3%82%92iPhone%E3%81%A7%E5%8B%95%E3%81%8B%E3%81%99)
6969

70+
[<img src="https://raw.githubusercontent.com/oreilly-japan/deep-learning-from-scratch-3/images/style_transfer.png" height="175"/>](/examples/style_transfer.py)
71+
7072
## 正誤表
7173

7274
本書の正誤情報は、[:mag_right: 正誤表ページ](../../wiki/Errata)に掲載しています。

examples/style_transfer.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from PIL import Image
4+
import dezero
5+
import dezero.functions as F
6+
from dezero import Variable
7+
from dezero.models import VGG16
8+
9+
10+
use_gpu = dezero.cuda.gpu_enable
11+
lr = 5.0
12+
iterations = 2001
13+
model_input_size = (224, 224)
14+
style_weight = 1.0
15+
content_weight = 1e-4
16+
total_varitaion_weight = 1e-6
17+
content_layers = ['conv5_2']
18+
style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
19+
content_url = 'https://github.com/oreilly-japan/deep-learning-from-scratch-3/raw/images/zebra.jpg'
20+
style_url = 'https://raw.githubusercontent.com/jcjohnson/neural-style/master/examples/inputs/starry_night_google.jpg'
21+
22+
23+
class VGG16(VGG16):
24+
def extract(self, x):
25+
c1_1 = F.relu(self.conv1_1(x))
26+
c1_2 = F.relu(self.conv1_2(c1_1))
27+
p1 = F.average_pooling(c1_2, 2, 2)
28+
c2_1 = F.relu(self.conv2_1(p1))
29+
c2_2 = F.relu(self.conv2_2(c2_1))
30+
p2 = F.average_pooling(c2_2, 2, 2)
31+
c3_1 = F.relu(self.conv3_1(p2))
32+
c3_2 = F.relu(self.conv3_2(c3_1))
33+
c3_3 = F.relu(self.conv3_3(c3_2))
34+
p3 = F.average_pooling(c3_3, 2, 2)
35+
c4_1 = F.relu(self.conv4_1(p3))
36+
c4_2 = F.relu(self.conv4_2(c4_1))
37+
c4_3 = F.relu(self.conv4_3(c4_2))
38+
p4 = F.average_pooling(c4_3, 2, 2)
39+
c5_1 = F.relu(self.conv5_1(p4))
40+
c5_2 = F.relu(self.conv5_2(c5_1))
41+
c5_3 = F.relu(self.conv5_3(c5_2))
42+
return {'conv1_1':c1_1, 'conv1_2':c1_2, 'conv2_1':c2_1, 'conv2_2':c2_2,
43+
'conv3_1':c3_1, 'conv3_2':c3_2, 'conv3_3':c3_3, 'conv4_1':c4_1,
44+
'conv5_1':c5_1, 'conv5_2':c5_2, 'conv5_3':c5_3}
45+
46+
# Setup for content & style image
47+
content_path = dezero.utils.get_file(content_url)
48+
style_path = dezero.utils.get_file(style_url)
49+
content_img = Image.open(content_path)
50+
content_size = content_img.size
51+
style_img = Image.open(style_path)
52+
content_img = VGG16.preprocess(content_img, size=model_input_size)[np.newaxis] # preprocess for VGG
53+
style_img = VGG16.preprocess(style_img, size=model_input_size)[np.newaxis]
54+
content_img, style_img = Variable(content_img), Variable(style_img)
55+
56+
model = VGG16(pretrained=True)
57+
#gen_data = np.random.uniform(-20, 20, (1, 3, img_resize[0], img_resize[1])).astype(np.float32)
58+
gen_data = content_img.data.copy()
59+
gen_img = dezero.Parameter(gen_data)
60+
gen_model = dezero.models.Model()
61+
gen_model.param = gen_img
62+
optimizer = dezero.optimizers.AdaGrad(lr=lr).setup(gen_model)
63+
64+
if use_gpu:
65+
model.to_gpu()
66+
gen_img.to_gpu()
67+
content_img.to_gpu()
68+
style_img.to_gpu()
69+
70+
71+
with dezero.no_grad():
72+
content_features = model.extract(content_img)
73+
style_features = model.extract(style_img)
74+
75+
76+
def deprocess_image(x, size=None):
77+
if use_gpu:
78+
x = dezero.cuda.as_numpy(x)
79+
if x.ndim == 4:
80+
x = np.squeeze(x)
81+
x = x.transpose((1,2,0))
82+
x += np.array([103.939, 116.779, 123.68])
83+
x = x[:,:,::-1] # BGR -> RGB
84+
x = np.clip(x, 0, 255).astype('uint8')
85+
img = Image.fromarray(x, mode="RGB")
86+
if size:
87+
img = img.resize(size)
88+
return img
89+
90+
91+
def gram_mat(x):
92+
N, C, H, W = x.shape
93+
features = x.reshape(C, -1)
94+
gram = F.matmul(features, features.T)
95+
return gram.reshape(1, C, C)
96+
97+
98+
def style_loss(style, comb):
99+
S = gram_mat(style)
100+
C = gram_mat(comb)
101+
N, ch, H, W = style.shape
102+
return F.mean_squared_error(S, C) / (4 * (ch * W * H)**2)
103+
104+
105+
def content_loss(base, comb):
106+
return F.mean_squared_error(base, comb) / 2
107+
108+
109+
def total_varitaion_loss(x):
110+
a = (x[:, :, :-1, :-1] - x[:, :, 1:, : -1]) ** 2
111+
b = (x[:, :, :-1, :-1] - x[:, :, : -1, 1:]) ** 2
112+
return F.sum(a + b)
113+
114+
115+
def loss_func(gen_features, content_features, style_features, gen_img):
116+
loss = 0
117+
# content loss
118+
for layer in content_features:
119+
loss += content_weight / len(content_layers) * \
120+
content_loss(content_features[layer], gen_features[layer])
121+
# style loss
122+
for layer in style_features:
123+
loss += style_weight / len(style_layers) * \
124+
style_loss(style_features[layer], gen_features[layer])
125+
# total variation loss
126+
loss += total_varitaion_weight * total_varitaion_loss(gen_img)
127+
return loss
128+
129+
130+
print_interval = 100 if use_gpu else 1
131+
for i in range(iterations):
132+
model.cleargrads()
133+
gen_img.cleargrad()
134+
135+
gen_features = model.extract(gen_img)
136+
loss = loss_func(gen_features, content_features, style_features, gen_img)
137+
loss.backward()
138+
optimizer.update()
139+
140+
if i % print_interval == 0:
141+
print('{} loss: {:.0f}'.format(i, float(loss.data)))
142+
143+
if i % 100 == 0:
144+
img = deprocess_image(gen_img.data, content_size)
145+
plt.imshow(np.array(img))
146+
plt.show()
147+
#img.save("style_transfer_{}.png".format(str(i)))

0 commit comments

Comments
 (0)