Skip to content

Commit 88a5b84

Browse files
authored
Merge pull request #1662 from rwightman/dataset_info
ImageNet metadata (info) and labelling update
2 parents 89b0452 + 7a0bd09 commit 88a5b84

23 files changed

+65547
-18
lines changed

MANIFEST.in

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
include timm/models/pruned/*.txt
2-
1+
include timm/models/_pruned/*.txt
2+
include timm/data/_info/*.txt
3+
include timm/data/_info/*.json

inference.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pandas as pd
1818
import torch
1919

20-
from timm.data import create_dataset, create_loader, resolve_data_config
20+
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
2121
from timm.layers import apply_test_time_pool
2222
from timm.models import create_model
2323
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
@@ -46,6 +46,7 @@
4646

4747
_FMT_EXT = {
4848
'json': '.json',
49+
'json-record': '.json',
4950
'json-split': '.json',
5051
'parquet': '.parquet',
5152
'csv': '.csv',
@@ -122,7 +123,7 @@
122123
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
123124
help="Enable AOT Autograd support.")
124125

125-
parser.add_argument('--results-dir',type=str, default=None,
126+
parser.add_argument('--results-dir', type=str, default=None,
126127
help='folder for output results')
127128
parser.add_argument('--results-file', type=str, default=None,
128129
help='results filename (relative to results-dir)')
@@ -134,14 +135,20 @@
134135
metavar='N', help='Top-k to output to CSV')
135136
parser.add_argument('--fullname', action='store_true', default=False,
136137
help='use full sample name in output (not just basename).')
137-
parser.add_argument('--filename-col', default='filename',
138+
parser.add_argument('--filename-col', type=str, default='filename',
138139
help='name for filename / sample name column')
139-
parser.add_argument('--index-col', default='index',
140+
parser.add_argument('--index-col', type=str, default='index',
140141
help='name for output indices column(s)')
141-
parser.add_argument('--output-col', default=None,
142+
parser.add_argument('--label-col', type=str, default='label',
143+
help='name for output label column(s)')
144+
parser.add_argument('--output-col', type=str, default=None,
142145
help='name for logit/probs output column(s)')
143-
parser.add_argument('--output-type', default='prob',
146+
parser.add_argument('--output-type', type=str, default='prob',
144147
help='output type column ("prob" for probabilities, "logit" for raw logits)')
148+
parser.add_argument('--label-type', type=str, default='description',
149+
help='type of label to output, one of "none", "name", "description", "detail"')
150+
parser.add_argument('--include-index', action='store_true', default=False,
151+
help='include the class index in results')
145152
parser.add_argument('--exclude-output', action='store_true', default=False,
146153
help='exclude logits/probs from results, just indices. topk must be set !=0.')
147154

@@ -237,10 +244,26 @@ def main():
237244
**data_config,
238245
)
239246

247+
to_label = None
248+
if args.label_type in ('name', 'description', 'detail'):
249+
imagenet_subset = infer_imagenet_subset(model)
250+
if imagenet_subset is not None:
251+
dataset_info = ImageNetInfo(imagenet_subset)
252+
if args.label_type == 'name':
253+
to_label = lambda x: dataset_info.index_to_label_name(x)
254+
elif args.label_type == 'detail':
255+
to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
256+
else:
257+
to_label = lambda x: dataset_info.index_to_description(x)
258+
to_label = np.vectorize(to_label)
259+
else:
260+
_logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")
261+
240262
top_k = min(args.topk, args.num_classes)
241263
batch_time = AverageMeter()
242264
end = time.time()
243265
all_indices = []
266+
all_labels = []
244267
all_outputs = []
245268
use_probs = args.output_type == 'prob'
246269
with torch.no_grad():
@@ -254,7 +277,12 @@ def main():
254277

255278
if top_k:
256279
output, indices = output.topk(top_k)
257-
all_indices.append(indices.cpu().numpy())
280+
np_indices = indices.cpu().numpy()
281+
if args.include_index:
282+
all_indices.append(np_indices)
283+
if to_label is not None:
284+
np_labels = to_label(np_indices)
285+
all_labels.append(np_labels)
258286

259287
all_outputs.append(output.cpu().numpy())
260288

@@ -267,6 +295,7 @@ def main():
267295
batch_idx, len(loader), batch_time=batch_time))
268296

269297
all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
298+
all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
270299
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
271300
filenames = loader.dataset.filenames(basename=not args.fullname)
272301

@@ -276,13 +305,20 @@ def main():
276305
if all_indices is not None:
277306
for i in range(all_indices.shape[-1]):
278307
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
308+
if all_labels is not None:
309+
for i in range(all_labels.shape[-1]):
310+
data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
279311
for i in range(all_outputs.shape[-1]):
280312
data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
281313
else:
282314
if all_indices is not None:
283315
if all_indices.shape[-1] == 1:
284316
all_indices = all_indices.squeeze(-1)
285317
data_dict[args.index_col] = list(all_indices)
318+
if all_labels is not None:
319+
if all_labels.shape[-1] == 1:
320+
all_labels = all_labels.squeeze(-1)
321+
data_dict[args.label_col] = list(all_labels)
286322
if all_outputs.shape[-1] == 1:
287323
all_outputs = all_outputs.squeeze(-1)
288324
data_dict[output_col] = list(all_outputs)
@@ -291,7 +327,7 @@ def main():
291327

292328
results_filename = args.results_file
293329
if results_filename:
294-
filename_no_ext, ext = os.path.splitext(results_filename)[-1]
330+
filename_no_ext, ext = os.path.splitext(results_filename)
295331
if ext and ext in _FMT_EXT.values():
296332
# if filename provided with one of expected ext,
297333
# remove it as it will be added back
@@ -308,14 +344,16 @@ def main():
308344
save_results(df, results_filename, fmt)
309345

310346
print(f'--result')
311-
print(json.dumps(dict(filename=results_filename)))
347+
print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
312348

313349

314350
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
315351
results_filename += _FMT_EXT[results_format]
316352
if results_format == 'parquet':
317353
df.set_index(filename_col).to_parquet(results_filename)
318354
elif results_format == 'json':
355+
df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
356+
elif results_format == 'json-record':
319357
df.to_json(results_filename, lines=True, orient='records')
320358
elif results_format == 'json-split':
321359
df.to_json(results_filename, indent=4, orient='split', index=False)

timm/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from .constants import *
55
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
66
from .dataset_factory import create_dataset
7+
from .dataset_info import DatasetInfo
8+
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
79
from .loader import create_loader
810
from .mixup import Mixup, FastCollateMixup
911
from .readers import create_reader
File renamed without changes.

0 commit comments

Comments
 (0)