@@ -36,13 +36,17 @@
 import pathlib
 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 
-import torch
 import numpy as np
-from scipy.misc import imread
+import torch
 from scipy import linalg
-from torch.autograd import Variable
+from scipy.misc import imread
 from torch.nn.functional import adaptive_avg_pool2d
-from tqdm import tqdm
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    # If tqdm is not available, provide a mock version of it
+    def tqdm(x): return x
 
 from inception import InceptionV3
 
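The new import block makes tqdm optional: when the library is missing, the fallback tqdm(x) simply returns the iterable, so every call site below works unchanged, just without a progress bar. A minimal standalone illustration of the shim (the loop is illustrative):

    try:
        from tqdm import tqdm      # real progress bar if installed
    except ImportError:
        def tqdm(x): return x      # no-op shim: hands the iterable back

    for i in tqdm(range(3)):       # identical call shape either way
        pass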
@@ -65,8 +69,7 @@ def get_activations(files, model, batch_size=64, dims=2048,
     """Calculates the activations of the pool_3 layer for all images.
 
     Params:
-    -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
-                     must lie between 0 and 1.
+    -- files       : List of image file paths
     -- model       : Instance of inception model
     -- batch_size  : the images numpy array is split into batches with
                      batch size batch_size. A reasonable batch size depends
@@ -81,34 +84,32 @@ def get_activations(files, model, batch_size=64, dims=2048,
        query tensor.
     """
     model.eval()
-
-    #calculate number of total files
-    d0 = len(files)
-    if batch_size > d0:
+
+    if batch_size > len(files):
         print(('Warning: batch size is bigger than the data size. '
                'Setting batch size to data size'))
-        batch_size = d0
+        batch_size = len(files)
 
-    n_batches = d0 // batch_size
+    n_batches = len(files) // batch_size
     n_used_imgs = n_batches * batch_size
 
     pred_arr = np.empty((n_used_imgs, dims))
-
-    #Add processbar to know process
+
     for i in tqdm(range(n_batches)):
         if verbose:
             print('\rPropagating batch %d/%d' % (i + 1, n_batches),
                   end='', flush=True)
         start = i * batch_size
         end = start + batch_size
-
-        # real batch of images here
-        images = np.array([imread(str(fn)).astype(np.float32) for fn in files[start:end]])
+
+        images = np.array([imread(str(f)).astype(np.float32)
+                           for f in files[start:end]])
+
+        # Reshape to (n_images, 3, height, width)
         images = images.transpose((0, 3, 1, 2))
         images /= 255
-
+
         batch = torch.from_numpy(images).type(torch.FloatTensor)
-        batch = Variable(batch, volatile=True)
         if cuda:
             batch = batch.cuda()
 
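The substantive change in get_activations is lazy loading: the function now receives file paths and reads only one batch of images from disk per iteration, instead of requiring the whole (n_images, 3, hi, wi) array up front. The dropped Variable(batch, volatile=True) wrapper is obsolete since PyTorch 0.4, where tensors are fed to the model directly (inference code would typically wrap the forward pass in torch.no_grad() instead). A standalone sketch of the lazy-batching pattern, with an illustrative helper name; scipy.misc.imread is kept for consistency with this file, although it is removed in newer SciPy:

    import numpy as np
    from scipy.misc import imread

    def iter_image_batches(files, batch_size):
        """Yield (batch_size, 3, H, W) float arrays scaled to [0, 1],
        loading one batch of files from disk at a time."""
        for i in range(len(files) // batch_size):
            chunk = files[i * batch_size:(i + 1) * batch_size]
            images = np.array([imread(str(f)).astype(np.float32)
                               for f in chunk])
            images = images.transpose((0, 3, 1, 2))  # (N, H, W, 3) -> (N, 3, H, W)
            images /= 255
            yield images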
@@ -139,10 +140,10 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     -- mu1   : Numpy array containing the activations of a layer of the
                inception net (like returned by the function 'get_predictions')
                for generated samples.
-    -- mu2   : The sample mean over activations, precalculated on an
+    -- mu2   : The sample mean over activations, precalculated on a
                representative data set.
     -- sigma1: The covariance matrix over activations for generated samples.
-    -- sigma2: The covariance matrix over activations, precalculated on an
+    -- sigma2: The covariance matrix over activations, precalculated on a
                representative data set.
 
     Returns:
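For reference, the quantity this function computes is the squared Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2), d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*(sigma1 . sigma2)^(1/2)). A minimal numeric sketch with toy 2-D statistics (real FID uses dims=2048 pool_3 statistics; the eps stabilizer of the full implementation is omitted here):

    import numpy as np
    from scipy import linalg

    mu1, sigma1 = np.zeros(2), np.eye(2)        # toy "generated" statistics
    mu2, sigma2 = np.ones(2), 2 * np.eye(2)     # toy "reference" statistics

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real                  # sqrtm may pick up a tiny imaginary part
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)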
@@ -188,8 +189,7 @@ def calculate_activation_statistics(files, model, batch_size=64,
                                     dims=2048, cuda=False, verbose=False):
     """Calculation of the statistics used by the FID.
     Params:
-    -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
-                     must lie between 0 and 1.
+    -- files       : List of image file paths
     -- model       : Instance of inception model
     -- batch_size  : The images numpy array is split into batches with
                      batch size batch_size. A reasonable batch size
@@ -204,31 +204,20 @@ def calculate_activation_statistics(files, model, batch_size=64,
     -- sigma : The covariance matrix of the activations of the pool_3 layer of
                the inception model.
     """
-    # Instead of load all the images, we pass the file name list
     act = get_activations(files, model, batch_size, dims, cuda, verbose)
     mu = np.mean(act, axis=0)
     sigma = np.cov(act, rowvar=False)
     return mu, sigma
 
 
-def _compute_statistics_of_path(path, model, batch_size, dims, cuda, flag):
+def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
     if path.endswith('.npz'):
         f = np.load(path)
         m, s = f['mu'][:], f['sigma'][:]
         f.close()
     else:
         path = pathlib.Path(path)
         files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
-
-        # imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
-
-        # Bring images to shape (B, 3, H, W)
-        # imgs = imgs.transpose((0, 3, 1, 2))
-
-        # Rescale images to be between 0 and 1
-        # imgs /= 255
-
-        # Instead of load all the images, we pass the file name list
         m, s = calculate_activation_statistics(files, model, batch_size,
                                                dims, cuda)
@@ -246,11 +235,11 @@ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
     model = InceptionV3([block_idx])
     if cuda:
         model.cuda()
-
+
     m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
-                                         dims, cuda, 1)
+                                         dims, cuda)
     m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
-                                         dims, cuda, 0)
+                                         dims, cuda)
     fid_value = calculate_frechet_distance(m1, s1, m2, s2)
 
     return fid_value
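Putting it together, calculate_fid_given_paths is the public entry point and takes two paths, each either an image folder or an .npz statistics file. A usage sketch, assuming the module is importable as fid_score (the module name is an assumption):

    from fid_score import calculate_fid_given_paths

    paths = ['path/to/real_images', 'path/to/generated_images']
    fid_value = calculate_fid_given_paths(paths, batch_size=64, cuda=False,
                                          dims=2048)
    print('FID:', fid_value)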