1717import pandas as pd
1818import torch
1919
20- from timm .data import create_dataset , create_loader , resolve_data_config
20+ from timm .data import create_dataset , create_loader , resolve_data_config , ImageNetInfo , infer_imagenet_subset
2121from timm .layers import apply_test_time_pool
2222from timm .models import create_model
2323from timm .utils import AverageMeter , setup_default_logging , set_jit_fuser , ParseKwargs
4646
4747_FMT_EXT = {
4848 'json' : '.json' ,
49+ 'json-record' : '.json' ,
4950 'json-split' : '.json' ,
5051 'parquet' : '.parquet' ,
5152 'csv' : '.csv' ,
122123scripting_group .add_argument ('--aot-autograd' , default = False , action = 'store_true' ,
123124 help = "Enable AOT Autograd support." )
124125
125- parser .add_argument ('--results-dir' ,type = str , default = None ,
126+ parser .add_argument ('--results-dir' , type = str , default = None ,
126127 help = 'folder for output results' )
127128parser .add_argument ('--results-file' , type = str , default = None ,
128129 help = 'results filename (relative to results-dir)' )
134135 metavar = 'N' , help = 'Top-k to output to CSV' )
135136parser .add_argument ('--fullname' , action = 'store_true' , default = False ,
136137 help = 'use full sample name in output (not just basename).' )
137- parser .add_argument ('--filename-col' , default = 'filename' ,
138+ parser .add_argument ('--filename-col' , type = str , default = 'filename' ,
138139 help = 'name for filename / sample name column' )
139- parser .add_argument ('--index-col' , default = 'index' ,
140+ parser .add_argument ('--index-col' , type = str , default = 'index' ,
140141 help = 'name for output indices column(s)' )
141- parser .add_argument ('--output-col' , default = None ,
# Column name under which predicted class labels are written to the results
# file (mirrors --index-col / --output-col naming options).
parser.add_argument('--label-col', type=str, default='label',
                    help='name for output label column(s)')
144+ parser .add_argument ('--output-col' , type = str , default = None ,
142145 help = 'name for logit/probs output column(s)' )
143- parser .add_argument ('--output-type' , default = 'prob' ,
146+ parser .add_argument ('--output-type' , type = str , default = 'prob' ,
144147 help = 'output type colum ("prob" for probabilities, "logit" for raw logits)' )
# NOTE(review): the dispatch code checks for the literal 'detail'
# (`args.label_type == 'detail'`), not 'detailed' — the help text previously
# advertised "detailed", which would silently produce no labels.
parser.add_argument('--label-type', type=str, default='description',
                    help='type of label to output, one of "none", "name", "description", "detail"')
150+ parser .add_argument ('--include-index' , action = 'store_true' , default = False ,
151+ help = 'include the class index in results' )
145152parser .add_argument ('--exclude-output' , action = 'store_true' , default = False ,
146153 help = 'exclude logits/probs from results, just indices. topk must be set !=0.' )
147154
@@ -237,10 +244,26 @@ def main():
237244 ** data_config ,
238245 )
239246
247+ to_label = None
248+ if args .label_type in ('name' , 'description' , 'detail' ):
249+ imagenet_subset = infer_imagenet_subset (model )
250+ if imagenet_subset is not None :
251+ dataset_info = ImageNetInfo (imagenet_subset )
252+ if args .label_type == 'name' :
253+ to_label = lambda x : dataset_info .index_to_label_name (x )
254+ elif args .label_type == 'detail' :
255+ to_label = lambda x : dataset_info .index_to_description (x , detailed = True )
256+ else :
257+ to_label = lambda x : dataset_info .index_to_description (x )
258+ to_label = np .vectorize (to_label )
259+ else :
260+ _logger .error ("Cannot deduce ImageNet subset from model, no labelling will be performed." )
261+
240262 top_k = min (args .topk , args .num_classes )
241263 batch_time = AverageMeter ()
242264 end = time .time ()
243265 all_indices = []
266+ all_labels = []
244267 all_outputs = []
245268 use_probs = args .output_type == 'prob'
246269 with torch .no_grad ():
@@ -254,7 +277,12 @@ def main():
254277
255278 if top_k :
256279 output , indices = output .topk (top_k )
257- all_indices .append (indices .cpu ().numpy ())
280+ np_indices = indices .cpu ().numpy ()
281+ if args .include_index :
282+ all_indices .append (np_indices )
283+ if to_label is not None :
284+ np_labels = to_label (np_indices )
285+ all_labels .append (np_labels )
258286
259287 all_outputs .append (output .cpu ().numpy ())
260288
@@ -267,6 +295,7 @@ def main():
267295 batch_idx , len (loader ), batch_time = batch_time ))
268296
269297 all_indices = np .concatenate (all_indices , axis = 0 ) if all_indices else None
298+ all_labels = np .concatenate (all_labels , axis = 0 ) if all_labels else None
270299 all_outputs = np .concatenate (all_outputs , axis = 0 ).astype (np .float32 )
271300 filenames = loader .dataset .filenames (basename = not args .fullname )
272301
@@ -276,13 +305,20 @@ def main():
276305 if all_indices is not None :
277306 for i in range (all_indices .shape [- 1 ]):
278307 data_dict [f'{ args .index_col } _{ i } ' ] = all_indices [:, i ]
308+ if all_labels is not None :
309+ for i in range (all_labels .shape [- 1 ]):
310+ data_dict [f'{ args .label_col } _{ i } ' ] = all_labels [:, i ]
279311 for i in range (all_outputs .shape [- 1 ]):
280312 data_dict [f'{ output_col } _{ i } ' ] = all_outputs [:, i ]
281313 else :
282314 if all_indices is not None :
283315 if all_indices .shape [- 1 ] == 1 :
284316 all_indices = all_indices .squeeze (- 1 )
285317 data_dict [args .index_col ] = list (all_indices )
318+ if all_labels is not None :
319+ if all_labels .shape [- 1 ] == 1 :
320+ all_labels = all_labels .squeeze (- 1 )
321+ data_dict [args .label_col ] = list (all_labels )
286322 if all_outputs .shape [- 1 ] == 1 :
287323 all_outputs = all_outputs .squeeze (- 1 )
288324 data_dict [output_col ] = list (all_outputs )
@@ -291,7 +327,7 @@ def main():
291327
292328 results_filename = args .results_file
293329 if results_filename :
294- filename_no_ext , ext = os .path .splitext (results_filename )[ - 1 ]
330+ filename_no_ext , ext = os .path .splitext (results_filename )
295331 if ext and ext in _FMT_EXT .values ():
296332 # if filename provided with one of expected ext,
297333 # remove it as it will be added back
@@ -308,14 +344,16 @@ def main():
308344 save_results (df , results_filename , fmt )
309345
310346 print (f'--result' )
311- print (json . dumps ( dict ( filename = results_filename ) ))
347+ print (df . set_index ( args . filename_col ). to_json ( orient = 'index' , indent = 4 ))
312348
313349
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
    """Write an inference results DataFrame to disk in the requested format.

    Args:
        df: results DataFrame; must contain `filename_col`.
        results_filename: output path WITHOUT extension; the extension for
            `results_format` is appended from `_FMT_EXT`.
        results_format: one of the `_FMT_EXT` keys
            ('json', 'json-record', 'json-split', 'parquet', 'csv').
        filename_col: column used as the index for index-oriented formats.
    """
    results_filename += _FMT_EXT[results_format]
    if results_format == 'parquet':
        df.set_index(filename_col).to_parquet(results_filename)
    elif results_format == 'json':
        df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
    elif results_format == 'json-record':
        # FIX: was 'json-records', which can never match — _FMT_EXT declares
        # the format key as 'json-record', so that format fell through.
        df.to_json(results_filename, lines=True, orient='records')
    elif results_format == 'json-split':
        df.to_json(results_filename, indent=4, orient='split', index=False)
0 commit comments