1+ import matplotlib .pyplot as plt
2+ import numpy as np
3+ from PIL import Image
4+ import dezero
5+ import dezero .functions as F
6+ from dezero import Variable
7+ from dezero .models import VGG16
8+
9+
# ---------------------------------------------------------------------------
# Script configuration
# ---------------------------------------------------------------------------

# Taken straight from dezero's CUDA module; the GPU-specific branches further
# down in this script all key off this flag.
use_gpu = dezero.cuda.gpu_enable

# Optimisation settings.
lr = 5.0
iterations = 2001

# Size passed to VGG16.preprocess for both input images.
model_input_size = (224, 224)

# Relative weights of the three terms summed in loss_func.
style_weight = 1.0
content_weight = 1e-4
total_varitaion_weight = 1e-6

# Names of the extracted feature maps associated with each loss term.
content_layers = ['conv5_2']
style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']

# Remote locations of the two input images (downloaded via dezero.utils.get_file).
content_url = 'https://github.com/oreilly-japan/deep-learning-from-scratch-3/raw/images/zebra.jpg'
style_url = 'https://raw.githubusercontent.com/jcjohnson/neural-style/master/examples/inputs/starry_night_google.jpg'
21+
22+
class VGG16(VGG16):
    """Subclass of the imported ``VGG16`` (reusing its name) that adds
    intermediate feature extraction through the convolutional layers."""

    def extract(self, x):
        """Run ``x`` through the conv layers and collect ReLU activations.

        Each conv layer is followed by ``F.relu``; between consecutive
        stages the running activation goes through
        ``F.average_pooling(h, 2, 2)``. Nothing is pooled after the last
        stage.

        Args:
            x: Input handed to ``self.conv1_1``.

        Returns:
            dict mapping layer name to its post-ReLU activation, in layer
            order. ``conv4_2`` and ``conv4_3`` are computed (later layers
            depend on them) but are not part of the returned dict.
        """
        conv_stages = (
            ('conv1_1', 'conv1_2'),
            ('conv2_1', 'conv2_2'),
            ('conv3_1', 'conv3_2', 'conv3_3'),
            ('conv4_1', 'conv4_2', 'conv4_3'),
            ('conv5_1', 'conv5_2', 'conv5_3'),
        )
        # Layers that are evaluated but intentionally left out of the result.
        unreported = {'conv4_2', 'conv4_3'}

        activations = {}
        h = x
        for stage_index, layer_names in enumerate(conv_stages):
            if stage_index > 0:
                # Pooling sits between stages, never before the first one.
                h = F.average_pooling(h, 2, 2)
            for layer_name in layer_names:
                h = F.relu(getattr(self, layer_name)(h))
                if layer_name not in unreported:
                    activations[layer_name] = h
        return activations
45+
# ---------------------------------------------------------------------------
# Load the content & style images
# ---------------------------------------------------------------------------
content_path = dezero.utils.get_file(content_url)
style_path = dezero.utils.get_file(style_url)

content_img = Image.open(content_path)
# Keep the PIL size of the content image; the preview step resizes the
# generated image back to it.
content_size = content_img.size
style_img = Image.open(style_path)

# VGG preprocessing, then a new leading axis, then wrap as dezero Variables.
content_img = Variable(
    VGG16.preprocess(content_img, size=model_input_size)[np.newaxis])
style_img = Variable(
    VGG16.preprocess(style_img, size=model_input_size)[np.newaxis])

# ---------------------------------------------------------------------------
# Feature network, generated image and its optimizer
# ---------------------------------------------------------------------------
model = VGG16(pretrained=True)

# The generated image starts out as a copy of the preprocessed content image.
# (An earlier, disabled variant started from uniform random noise in
# [-20, 20] instead.)
gen_data = content_img.data.copy()
gen_img = dezero.Parameter(gen_data)

# The image is attached to an otherwise empty Model so that the optimizer
# can be set up on it.
gen_model = dezero.models.Model()
gen_model.param = gen_img
optimizer = dezero.optimizers.AdaGrad(lr=lr).setup(gen_model)

if use_gpu:
    for _device_target in (model, gen_img, content_img, style_img):
        _device_target.to_gpu()


# Features of the two fixed images are computed once, with gradient
# tracking switched off.
with dezero.no_grad():
    content_features = model.extract(content_img)
    style_features = model.extract(style_img)
74+
75+
def deprocess_image(x, size=None):
    """Convert an image array in network space into a PIL RGB image.

    The caller's array is left untouched. The previous implementation
    added the channel offsets in place on a squeeze/transpose view, which
    on the CPU path wrote the offsets back into the array that was passed
    in (e.g. the data of the image being optimised).

    Args:
        x: Channel-first image array, either 3-D or 4-D with singleton
            axes that ``np.squeeze`` removes. When ``use_gpu`` is set it
            is first converted with ``dezero.cuda.as_numpy``.
        size: Optional size forwarded to ``Image.resize``; a falsy value
            skips resizing.

    Returns:
        A PIL image in ``"RGB"`` mode.
    """
    if use_gpu:
        x = dezero.cuda.as_numpy(x)
    if x.ndim == 4:
        x = np.squeeze(x)
    x = x.transpose((1, 2, 0))
    # Out-of-place add on purpose: squeeze() and transpose() return views,
    # so ``x += ...`` would modify the caller's buffer.
    # NOTE(review): the offsets look like the per-channel means subtracted
    # by VGG16.preprocess -- confirm against that helper.
    x = x + np.array([103.939, 116.779, 123.68])
    x = x[:, :, ::-1]  # BGR -> RGB
    x = np.clip(x, 0, 255).astype('uint8')
    img = Image.fromarray(x, mode="RGB")
    if size:
        img = img.resize(size)
    return img
89+
90+
def gram_mat(x):
    """Return the Gram matrix of a feature map, shaped ``(1, C, C)``.

    ``x`` must have a four-entry shape; the second entry is taken as the
    channel count ``C`` and everything else is flattened into one axis
    (assumes a leading batch dimension of 1 -- TODO confirm with callers).
    """
    _, channels, _, _ = x.shape
    flat = x.reshape(channels, -1)
    return F.matmul(flat, flat.T).reshape(1, channels, channels)
96+
97+
def style_loss(style, comb):
    """Return the Gram-matrix discrepancy between two feature maps.

    Computes ``F.mean_squared_error`` between the Gram matrices of
    ``style`` and ``comb`` and divides it by ``4 * (ch * W * H) ** 2``,
    where ``ch``, ``H`` and ``W`` are read from ``style.shape``.
    """
    style_gram = gram_mat(style)
    comb_gram = gram_mat(comb)
    _, ch, H, W = style.shape
    normaliser = 4 * (ch * W * H) ** 2
    return F.mean_squared_error(style_gram, comb_gram) / normaliser
103+
104+
def content_loss(base, comb):
    """Return half of ``F.mean_squared_error(base, comb)``."""
    squared_error = F.mean_squared_error(base, comb)
    return squared_error / 2
107+
108+
def total_varitaion_loss(x):
    """Return the summed squared differences between neighbouring entries.

    Differences are taken along the last two axes of the 4-D input ``x``:
    each entry (excluding the final row and column) is compared with its
    successor along axis 2 and with its successor along axis 3.
    """
    along_axis2 = x[:, :, :-1, :-1] - x[:, :, 1:, :-1]
    along_axis3 = x[:, :, :-1, :-1] - x[:, :, :-1, 1:]
    return F.sum(along_axis2 ** 2 + along_axis3 ** 2)
113+
114+
def loss_func(gen_features, content_features, style_features, gen_img):
    """Return the weighted sum of content, style and total-variation losses.

    Only the layers named in ``content_layers`` / ``style_layers``
    contribute to the respective term. The previous implementation looped
    over every key of the feature dicts while still dividing by the length
    of the configured lists, so the configured layer selection was ignored
    and the per-layer weights were mis-scaled.

    Args:
        gen_features: Mapping of layer name to features of the generated
            image.
        content_features: Mapping of layer name to features of the content
            image; must contain every name in ``content_layers``.
        style_features: Mapping of layer name to features of the style
            image; must contain every name in ``style_layers``.
        gen_img: The generated image, used for the total-variation term.

    Returns:
        The combined loss.

    Raises:
        KeyError: If a configured layer name is missing from a mapping.
    """
    loss = 0
    # Content term, averaged over the configured content layers. The
    # division stays inside the loop so an empty layer list simply
    # contributes nothing instead of dividing by zero.
    for layer in content_layers:
        loss += content_weight / len(content_layers) * \
            content_loss(content_features[layer], gen_features[layer])
    # Style term, averaged over the configured style layers.
    for layer in style_layers:
        loss += style_weight / len(style_layers) * \
            style_loss(style_features[layer], gen_features[layer])
    # Total-variation term on the generated image itself.
    loss += total_varitaion_weight * total_varitaion_loss(gen_img)
    return loss
128+
129+
# ---------------------------------------------------------------------------
# Optimisation loop: only gen_img is updated by the optimizer.
# ---------------------------------------------------------------------------
# Log every step on CPU, every 100th step on GPU.
print_interval = 1 if not use_gpu else 100

for i in range(iterations):
    # Drop gradients left over from the previous step.
    model.cleargrads()
    gen_img.cleargrad()

    # Forward pass on the current generated image, then backprop and update.
    gen_features = model.extract(gen_img)
    loss = loss_func(gen_features, content_features, style_features, gen_img)
    loss.backward()
    optimizer.update()

    if not i % print_interval:
        print('{} loss: {:.0f}'.format(i, float(loss.data)))

    if not i % 100:
        # Preview the current result at the content image's original size.
        # (Call img.save(...) here instead to keep the previews on disk.)
        img = deprocess_image(gen_img.data, content_size)
        plt.imshow(np.array(img))
        plt.show()
0 commit comments