4242from scipy import linalg
4343from torch .autograd import Variable
4444from torch .nn .functional import adaptive_avg_pool2d
45+ from tqdm import tqdm
4546
4647from inception import InceptionV3
4748
48-
4949parser = ArgumentParser (formatter_class = ArgumentDefaultsHelpFormatter )
5050parser .add_argument ('path' , type = str , nargs = 2 ,
5151 help = ('Path to the generated images or '
5252 'to .npz statistic files' ))
53- parser .add_argument ('--batch-size' , type = int , default = 64 ,
53+ parser .add_argument ('--batch-size' , type = int , default = 256 ,
5454 help = 'Batch size to use' )
5555parser .add_argument ('--dims' , type = int , default = 2048 ,
5656 choices = list (InceptionV3 .BLOCK_INDEX_BY_DIM ),
6060 help = 'GPU to use (leave blank for CPU only)' )
6161
6262
63- def get_activations (images , model , batch_size = 64 , dims = 2048 ,
63+ def get_activations (files , model , batch_size = 64 , dims = 2048 ,
6464 cuda = False , verbose = False ):
6565 """Calculates the activations of the pool_3 layer for all images.
6666
@@ -81,8 +81,9 @@ def get_activations(images, model, batch_size=64, dims=2048,
8181 query tensor.
8282 """
8383 model .eval ()
84-
85- d0 = images .shape [0 ]
84+
85+ #calculate number of total files
86+ d0 = len (files )
8687 if batch_size > d0 :
8788 print (('Warning: batch size is bigger than the data size. '
8889 'Setting batch size to data size' ))
@@ -92,14 +93,21 @@ def get_activations(images, model, batch_size=64, dims=2048,
9293 n_used_imgs = n_batches * batch_size
9394
9495 pred_arr = np .empty ((n_used_imgs , dims ))
95- for i in range (n_batches ):
96+
97+ #Add processbar to know process
98+ for i in tqdm (range (n_batches )):
9699 if verbose :
97100 print ('\r Propagating batch %d/%d' % (i + 1 , n_batches ),
98101 end = '' , flush = True )
99102 start = i * batch_size
100103 end = start + batch_size
101-
102- batch = torch .from_numpy (images [start :end ]).type (torch .FloatTensor )
104+
105+ # real batch of images here
106+ images = np .array ([imread (str (fn )).astype (np .float32 ) for fn in files [start :end ]])
107+ images = images .transpose ((0 , 3 , 1 , 2 ))
108+ images /= 255
109+
110+ batch = torch .from_numpy (images ).type (torch .FloatTensor )
103111 batch = Variable (batch , volatile = True )
104112 if cuda :
105113 batch = batch .cuda ()
@@ -176,7 +184,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
176184 np .trace (sigma2 ) - 2 * tr_covmean )
177185
178186
179- def calculate_activation_statistics (images , model , batch_size = 64 ,
187+ def calculate_activation_statistics (files , model , batch_size = 64 ,
180188 dims = 2048 , cuda = False , verbose = False ):
181189 """Calculation of the statistics used by the FID.
182190 Params:
@@ -196,13 +204,14 @@ def calculate_activation_statistics(images, model, batch_size=64,
196204 -- sigma : The covariance matrix of the activations of the pool_3 layer of
197205 the inception model.
198206 """
199- act = get_activations (images , model , batch_size , dims , cuda , verbose )
207+ # Instead of load all the images, we pass the file name list
208+ act = get_activations (files , model , batch_size , dims , cuda , verbose )
200209 mu = np .mean (act , axis = 0 )
201210 sigma = np .cov (act , rowvar = False )
202211 return mu , sigma
203212
204213
205- def _compute_statistics_of_path (path , model , batch_size , dims , cuda ):
214+ def _compute_statistics_of_path (path , model , batch_size , dims , cuda , flag ):
206215 if path .endswith ('.npz' ):
207216 f = np .load (path )
208217 m , s = f ['mu' ][:], f ['sigma' ][:]
@@ -211,15 +220,16 @@ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
211220 path = pathlib .Path (path )
212221 files = list (path .glob ('*.jpg' )) + list (path .glob ('*.png' ))
213222
214- imgs = np .array ([imread (str (fn )).astype (np .float32 ) for fn in files ])
223+ # imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
215224
216225 # Bring images to shape (B, 3, H, W)
217- imgs = imgs .transpose ((0 , 3 , 1 , 2 ))
226+ # imgs = imgs.transpose((0, 3, 1, 2))
218227
219228 # Rescale images to be between 0 and 1
220- imgs /= 255
229+ # imgs /= 255
221230
222- m , s = calculate_activation_statistics (imgs , model , batch_size ,
231+ # Instead of load all the images, we pass the file name list
232+ m , s = calculate_activation_statistics (files , model , batch_size ,
223233 dims , cuda )
224234
225235 return m , s
@@ -236,11 +246,11 @@ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
236246 model = InceptionV3 ([block_idx ])
237247 if cuda :
238248 model .cuda ()
239-
249+
240250 m1 , s1 = _compute_statistics_of_path (paths [0 ], model , batch_size ,
241- dims , cuda )
251+ dims , cuda , 1 )
242252 m2 , s2 = _compute_statistics_of_path (paths [1 ], model , batch_size ,
243- dims , cuda )
253+ dims , cuda , 0 )
244254 fid_value = calculate_frechet_distance (m1 , s1 , m2 , s2 )
245255
246256 return fid_value
0 commit comments