diff --git a/README.md b/README.md index f4032334..8b94c261 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# Neuraltalk2-pytorch +# ImageCaptioning.pytorch -Changes compared to neuraltalk2. +This is an image captioning codebase in PyTorch. If you are familiar with neuraltalk2, here are the differences compared to neuraltalk2. - Instead of using random split, we use [karpathy's train-val-test split](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip). - Instead of including the convnet in the model, we use preprocessed features. (finetuneable cnn version is in the branch **with_finetune**) - Use resnet instead of vgg; the feature extraction method is the same as in self-critical: run cnn on original image and adaptively average pool the last conv layer feature to fixed size . @@ -8,7 +8,7 @@ Changes compared to neuraltalk2. ## Requirements Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3) -PyTorch 0.2 (along with torchvision) +PyTorch 0.4.1 (along with torchvision) You need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`. @@ -31,6 +31,7 @@ Once we have these, we can now invoke the `prepro_*.py` script, which will read ```bash $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT + ``` `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`. @@ -39,6 +40,12 @@ $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_di (Check the prepro scripts for more options, like other resnet models or other attention sizes.) +**Legacy:** previously we extract features into separate npy/npz files for each image, but it would be slower to load on some NFS and also to copy them around. We now save all the features in h5 file. If you want to convert from previous npy/npz files to h5 file, you can use run + +```bash +$ python scripts/convert_old.py --input_json data/dataset_coco.json --fc_input_dir data/cocotalk_fc/ --att_input_dir data/cocotalk_att/ --fc_output_dir data/cocotalk_fc --att_output_dir data/cocotalk_att/ +``` + **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix, it involves manually replacing one image in the dataset. ### Start training @@ -97,6 +104,22 @@ The defualt split to evaluate is test. The default inference method is greedy de **Live demo**. Not supported now. Welcome pull request. +## Reference +If you find this implementation helpful, please consider citing this repo: + +``` +@misc{Luo2017, +author = {Ruotian Luo}, +title = {An Image Captioning codebase in PyTorch}, +year = {2017}, +publisher = {GitHub}, +journal = {GitHub repository}, +howpublished = {\url{https://github.com/ruotianluo/ImageCaptioning.pytorch}}, +} +``` + +Of course, please cite the original paper of models you are using (You can find references in the model files). + ## Acknowledgements -Thanks the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and awesome PyTorch team. \ No newline at end of file +Thanks the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and awesome PyTorch team. diff --git a/dataloader.py b/dataloader.py index f1175356..e5d54fc1 100644 --- a/dataloader.py +++ b/dataloader.py @@ -8,22 +8,17 @@ import numpy as np import random -import torch import torch.utils.data as data import multiprocessing -def get_npy_data(ix, fc_file, att_file, use_att): - if use_att == True: - return (np.load(fc_file), np.load(att_file)['feat'], ix) - else: - return (np.load(fc_file), np.zeros((1,1,1)), ix) class DataLoader(data.Dataset): def reset_iterator(self, split): del self._prefetch_process[split] - self._prefetch_process[split] = BlobFetcher(split, self, split=='train') + self._prefetch_process[split] = BlobFetcher(split, + self, split == 'train') self.iterators[split] = 0 def get_vocab_size(self): @@ -35,22 +30,40 @@ def get_vocab(self): def get_seq_length(self): return self.seq_length + def read_files(self): + self.feats_fc = h5py.File(os.path.join( + self.opt.input_fc_dir, 'feats_fc.h5'), 'r') + self.feats_att = h5py.File(os.path.join( + self.opt.input_att_dir, 'feats_att.h5'), 'r') + + def get_data(self, ix): + self.read_files() + index = str(self.info['images'][ix]['id']) + if self.use_att: + return (np.array(self.feats_fc[index]).astype('float32'), + np.array(self.feats_att[index]).astype('float32'), ix) + else: + return (np.array(self.feats_fc[index]).astype('float32'), + np.zeros((1, 1, 1)).astype('float32'), ix) + def __init__(self, opt): self.opt = opt self.batch_size = self.opt.batch_size self.seq_per_img = opt.seq_per_img self.use_att = getattr(opt, 'use_att', True) - # load the json file which contains additional information about the dataset + # load json file which contains additional information about dataset print('DataLoader loading json file: ', opt.input_json) self.info = json.load(open(self.opt.input_json)) self.ix_to_word = self.info['ix_to_word'] self.vocab_size = len(self.ix_to_word) print('vocab size is ', self.vocab_size) - + # open the hdf5 file - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5) - self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') + print('DataLoader loading h5 file: ', opt.input_fc_dir, + opt.input_att_dir, opt.input_label_h5) + self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', + driver='core') self.input_fc_dir = self.opt.input_fc_dir self.input_att_dir = self.opt.input_att_dir @@ -64,7 +77,7 @@ def __init__(self, opt): self.label_end_ix = self.h5_label_file['label_end_ix'][:] self.num_images = self.label_start_ix.shape[0] - print('read %d image features' %(self.num_images)) + print('read %d image features' % (self.num_images)) # separate out indexes for each of the provided splits self.split_ix = {'train': [], 'val': [], 'test': []} @@ -76,23 +89,27 @@ def __init__(self, opt): self.split_ix['val'].append(ix) elif img['split'] == 'test': self.split_ix['test'].append(ix) - elif opt.train_only == 0: # restval + elif opt.train_only == 0: # restval self.split_ix['train'].append(ix) - print('assigned %d images to split train' %len(self.split_ix['train'])) - print('assigned %d images to split val' %len(self.split_ix['val'])) - print('assigned %d images to split test' %len(self.split_ix['test'])) + print('assigned %d images to split train' % len(self.split_ix['train'])) + print('assigned %d images to split val' % len(self.split_ix['val'])) + print('assigned %d images to split test' % len(self.split_ix['test'])) self.iterators = {'train': 0, 'val': 0, 'test': 0} - - self._prefetch_process = {} # The three prefetch process + + self._prefetch_process = {} # The three prefetch process for split in self.iterators.keys(): - self._prefetch_process[split] = BlobFetcher(split, self, split=='train') + self._prefetch_process[split] = BlobFetcher(split, + self, + split == 'train') # Terminate the child process when the parent exists + def cleanup(): print('Terminating BlobFetcher') for split in self.iterators.keys(): del self._prefetch_process[split] + import atexit atexit.register(cleanup) @@ -100,10 +117,12 @@ def get_batch(self, split, batch_size=None, seq_per_img=None): batch_size = batch_size or self.batch_size seq_per_img = seq_per_img or self.seq_per_img - fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32') - att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32') - label_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int') - mask_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'float32') + fc_batch = [] + att_batch = [] + label_batch = np.zeros( + [batch_size * seq_per_img, self.seq_length + 2], dtype='int') + mask_batch = np.zeros( + [batch_size * seq_per_img, self.seq_length + 2], dtype='float32') wrapped = False @@ -111,8 +130,6 @@ def get_batch(self, split, batch_size=None, seq_per_img=None): gts = [] for i in range(batch_size): - import time - t_start = time.time() # fetch image tmp_fc, tmp_att,\ ix, tmp_wrapped = self._prefetch_process[split].get() @@ -120,76 +137,84 @@ def get_batch(self, split, batch_size=None, seq_per_img=None): att_batch += [tmp_att] * seq_per_img # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 + ix1 = self.label_start_ix[ix] - 1 # label_start_ix starts from 1 ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' + ncap = ix2 - ix1 + 1 # number of captions available for this image + assert ncap > 0, 'an image does not have any label.' if ncap < seq_per_img: # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') + seq = np.zeros([seq_per_img, self.seq_length], dtype='int') for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] + ixl = random.randint(ix1, ix2) + seq[q, :] = self.h5_label_file['labels'][ixl, + :self.seq_length] else: ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length] - - label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq + seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, + :self.seq_length] + + label_batch[i * seq_per_img: (i + 1) * seq_per_img, + 1: self.seq_length + 1] = seq if tmp_wrapped: wrapped = True # Used for reward evaluation - gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) - + gts.append( + self.h5_label_file['labels'][self.label_start_ix[ix] - 1: + self.label_end_ix[ix]]) + # record associated info as well info_dict = {} info_dict['ix'] = ix info_dict['id'] = self.info['images'][ix]['id'] info_dict['file_path'] = self.info['images'][ix]['file_path'] infos.append(info_dict) - #print(i, time.time() - t_start) # generate mask - t_start = time.time() - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch))) + nonzeros = np.array(list(map(lambda x: (x != 0).sum() + 2, label_batch))) for ix, row in enumerate(mask_batch): row[:nonzeros[ix]] = 1 - #print('mask', time.time() - t_start) data = {} data['fc_feats'] = np.stack(fc_batch) data['att_feats'] = np.stack(att_batch) data['labels'] = label_batch data['gts'] = gts - data['masks'] = mask_batch - data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} + data['masks'] = mask_batch + data['bounds'] = {'it_pos_now': self.iterators[split], + 'it_max': len(self.split_ix[split]), + 'wrapped': wrapped} data['infos'] = infos return data - # It's not coherent to make DataLoader a subclass of Dataset, but essentially, we only need to implement the following to functions, - # so that the torch.utils.data.DataLoader can load the data according the index. - # However, it's minimum change to switch to pytorch data loading. + # It's not coherent to make DataLoader a subclass of Dataset, + # but essentially, we only need to implement the following to functions, + # so that the torch.utils.data.DataLoader can load the data according + # the index. However, it's minimum change to switch to pytorch data loading def __getitem__(self, index): """This function returns a tuple that is further passed to collate_fn """ - ix = index #self.split_ix[index] - return get_npy_data(ix, \ - os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'), - os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'), - self.use_att - ) + ix = index # self.split_ix[index] + return self.get_data(ix) def __len__(self): return len(self.info['images']) + +class ArraySampler(data.sampler.SubsetRandomSampler): + def __iter__(self): + return iter(self.indices) + + class BlobFetcher(): """Experimental class for prefetching blobs in a separate process.""" def __init__(self, split, dataloader, if_shuffle=False): """ - db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname + db is a list of tuples containing: imcrop_name, + caption, bbox_feat of gt box, imname """ self.split = split self.dataloader = dataloader @@ -199,17 +224,22 @@ def __init__(self, split, dataloader, if_shuffle=False): def reset(self): """ Two cases: - 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator - 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already. + 1. not hasattr(self, 'split_loader'): Resume from previous training. + Create the dataset given the saved split_ix and iterator + 2. wrapped: a new epoch, the split_ix and iterator have been updated in + the get_minibatch_inds already. """ # batch_size is 0, the merge is done in DataLoader class - self.split_loader = iter(data.DataLoader(dataset=self.dataloader, - batch_size=1, - sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:], - shuffle=False, - pin_memory=True, - num_workers=multiprocessing.cpu_count(), - collate_fn=lambda x: x[0])) + sampler = ArraySampler( + self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]) + self.split_loader = iter( + data.DataLoader(dataset=self.dataloader, + batch_size=1, + sampler=sampler, + shuffle=False, + pin_memory=True, + num_workers=multiprocessing.cpu_count(), + collate_fn=lambda x: x[0])) def _get_next_minibatch_inds(self): max_index = len(self.dataloader.split_ix[self.split]) @@ -227,7 +257,7 @@ def _get_next_minibatch_inds(self): self.dataloader.iterators[self.split] = ri_next return ix, wrapped - + def get(self): if not hasattr(self, 'split_loader'): self.reset() @@ -236,7 +266,6 @@ def get(self): tmp = self.split_loader.next() if wrapped: self.reset() - assert tmp[2] == ix, "ix not equal" - return tmp + [wrapped] \ No newline at end of file + return tmp + [wrapped] diff --git a/dataloaderraw.py b/dataloaderraw.py index d2180770..01a142b4 100644 --- a/dataloaderraw.py +++ b/dataloaderraw.py @@ -108,9 +108,10 @@ def get_batch(self, split, batch_size=None): img = np.concatenate((img, img, img), axis=2) img = img.astype('float32')/255.0 - img = torch.from_numpy(img.transpose([2,0,1])).cuda() - img = Variable(preprocess(img), volatile=True) - tmp_fc, tmp_att = self.my_resnet(img) + img = torch.from_numpy(img.transpose([2, 0, 1])).cuda() + with torch.no_grad(): + img = Variable(preprocess(img)) + tmp_fc, tmp_att = self.my_resnet(img) fc_batch[i] = tmp_fc.data.cpu().float().numpy() att_batch[i] = tmp_att.data.cpu().float().numpy() @@ -136,4 +137,3 @@ def get_vocab_size(self): def get_vocab(self): return self.ix_to_word - \ No newline at end of file diff --git a/eval.py b/eval.py index 9d26932b..1f859dd1 100644 --- a/eval.py +++ b/eval.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,59 +22,59 @@ # Input arguments and options parser = argparse.ArgumentParser() # Input paths -parser.add_argument('--model', type=str, default='', - help='path to model to evaluate') +parser.add_argument('--model', type=str, required=True, + help='path to model to evaluate') parser.add_argument('--cnn_model', type=str, default='resnet101', - help='resnet101, resnet152') -parser.add_argument('--infos_path', type=str, default='', - help='path to infos to evaluate') + help='resnet101, resnet152') +parser.add_argument('--infos_path', type=str, required=True, + help='path to infos to evaluate') # Basic options parser.add_argument('--batch_size', type=int, default=0, - help='if > 0 then overrule, otherwise load from checkpoint.') + help='if > 0 then overrule, otherwise load from checkpoint.') parser.add_argument('--num_images', type=int, default=-1, - help='how many images to use when periodically evaluating the loss? (-1 = all)') + help='how many images to use when periodically evaluating the loss? (-1 = all)') parser.add_argument('--language_eval', type=int, default=0, - help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') + help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') parser.add_argument('--dump_images', type=int, default=1, - help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') + help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') parser.add_argument('--dump_json', type=int, default=1, - help='Dump json with predictions into vis folder? (1=yes,0=no)') + help='Dump json with predictions into vis folder? (1=yes,0=no)') parser.add_argument('--dump_path', type=int, default=0, - help='Write image paths along with predictions into vis json? (1=yes,0=no)') + help='Write image paths along with predictions into vis json? (1=yes,0=no)') # Sampling options parser.add_argument('--sample_max', type=int, default=1, - help='1 = sample argmax words. 0 = sample from distributions.') + help='1 = sample argmax words. 0 = sample from distributions.') parser.add_argument('--beam_size', type=int, default=2, - help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') + help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') parser.add_argument('--temperature', type=float, default=1.0, - help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') + help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') # For evaluation on a folder of images: -parser.add_argument('--image_folder', type=str, default='', - help='If this is nonempty then will predict on the images in this folder path') -parser.add_argument('--image_root', type=str, default='', - help='In case the image paths have to be preprended with a root path to an image folder') +parser.add_argument('--image_folder', type=str, default='', + help='If this is nonempty then will predict on the images in this folder path') +parser.add_argument('--image_root', type=str, default='', + help='In case the image paths have to be preprended with a root path to an image folder') # For evaluation on MSCOCO images from some split: parser.add_argument('--input_fc_dir', type=str, default='', - help='path to the h5file containing the preprocessed dataset') + help='path to the h5file containing the preprocessed dataset') parser.add_argument('--input_att_dir', type=str, default='', - help='path to the h5file containing the preprocessed dataset') + help='path to the h5file containing the preprocessed dataset') parser.add_argument('--input_label_h5', type=str, default='', - help='path to the h5file containing the preprocessed dataset') -parser.add_argument('--input_json', type=str, default='', - help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') -parser.add_argument('--split', type=str, default='test', - help='if running on MSCOCO images, which split to use: val|test|train') -parser.add_argument('--coco_json', type=str, default='', - help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') + help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_json', type=str, default='', + help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') +parser.add_argument('--split', type=str, default='test', + help='if running on MSCOCO images, which split to use: val|test|train') +parser.add_argument('--coco_json', type=str, default='', + help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') # misc -parser.add_argument('--id', type=str, default='', - help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') +parser.add_argument('--id', type=str, default='', + help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') opt = parser.parse_args() # Load infos -with open(opt.infos_path) as f: +with open(opt.infos_path, 'rb') as f: infos = cPickle.load(f) # override and collect parameters @@ -106,9 +107,9 @@ # Create the Data Loader instance if len(opt.image_folder) == 0: - loader = DataLoader(opt) + loader = DataLoader(opt) else: - loader = DataLoaderRaw({'folder_path': opt.image_folder, + loader = DataLoaderRaw({'folder_path': opt.image_folder, 'coco_json': opt.coco_json, 'batch_size': opt.batch_size, 'cnn_model': opt.cnn_model}) @@ -118,12 +119,13 @@ # Set sample options -loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, +loss, split_predictions, lang_stats = eval_utils.eval_split( + model, crit, loader, vars(opt)) print('loss: ', loss) if lang_stats: - print(lang_stats) + print(lang_stats) if opt.dump_json == 1: # dump the json diff --git a/eval_utils.py b/eval_utils.py index ab0abd06..71911b9a 100644 --- a/eval_utils.py +++ b/eval_utils.py @@ -85,10 +85,11 @@ def eval_split(model, crit, loader, eval_kwargs={}): if data.get('labels', None) is not None: # forward the model to get loss tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] - tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] - fc_feats, att_feats, labels, masks = tmp + with torch.no_grad(): + tmp = [Variable(torch.from_numpy(_)).cuda() for _ in tmp] + fc_feats, att_feats, labels, masks = tmp - loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]).data[0] + loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]).item() loss_sum = loss_sum + loss loss_evals = loss_evals + 1 @@ -96,10 +97,12 @@ def eval_split(model, crit, loader, eval_kwargs={}): # Only leave one feature for each image, in case duplicate sample tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]] - tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] - fc_feats, att_feats = tmp - # forward the model to also get generated samples for each image - seq, _ = model.sample(fc_feats, att_feats, eval_kwargs) + with torch.no_grad(): + tmp = [Variable(torch.from_numpy(_)).cuda() for _ in tmp] + fc_feats, att_feats = tmp + # forward the model to also get generated samples for each image + seq, _ = model.sample(fc_feats, att_feats, eval_kwargs) + seq = seq.cpu().numpy() #set_trace() sents = utils.decode_sequence(loader.get_vocab(), seq) diff --git a/misc/utils.py b/misc/utils.py index 6ccd0f94..88a08af2 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -15,7 +15,7 @@ def if_use_att(caption_model): # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. def decode_sequence(ix_to_word, seq): - N, D = seq.size() + N, D = seq.shape out = [] for i in range(N): txt = '' @@ -42,9 +42,9 @@ def __init__(self): def forward(self, input, target, mask): # truncate to the same size - target = target[:, :input.size(1)] - mask = mask[:, :input.size(1)] - input = to_contiguous(input).view(-1, input.size(2)) + target = target[:, :input.shape[1]] + mask = mask[:, :input.shape[1]] + input = to_contiguous(input).view(-1, input.shape[2]) target = to_contiguous(target).view(-1, 1) mask = to_contiguous(mask).view(-1, 1) output = - input.gather(1, target) * mask @@ -59,4 +59,4 @@ def set_lr(optimizer, lr): def clip_gradient(optimizer, grad_clip): for group in optimizer.param_groups: for param in group['params']: - param.grad.data.clamp_(-grad_clip, grad_clip) \ No newline at end of file + param.grad.data.clamp_(-grad_clip, grad_clip) diff --git a/models/CaptionModel.py b/models/CaptionModel.py index 4f04fcdc..14062bee 100644 --- a/models/CaptionModel.py +++ b/models/CaptionModel.py @@ -37,7 +37,7 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq #beam_logprobs_sum : joint log-probability of each beam - ys,ix = torch.sort(logprobsf,1,True) + ys, ix = torch.sort(logprobsf, 1, True) candidates = [] cols = min(beam_size, ys.size(1)) rows = beam_size @@ -45,10 +45,12 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob rows = 1 for c in range(cols): # for each column (word, essentially) for q in range(rows): # for each beam expansion - #compute logprob of expanding beam q with word in (sorted) position c - local_logprob = ys[q,c] - candidate_logprob = beam_logprobs_sum[q] + local_logprob - candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_logprob}) + # compute logprob of expanding beam q with word in (sorted) position c + local_logprob = ys[q, c] + candidate_logprob = beam_logprobs_sum[q] + local_logprob.cpu() + candidates.append(dict(c=ix[q, c], q=q, + p=candidate_logprob, + r=local_logprob)) candidates = sorted(candidates, key=lambda x: -x['p']) new_state = [_.clone() for _ in state] @@ -80,7 +82,8 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() - beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam + # running sum of logprobs for each beam + beam_logprobs_sum = torch.zeros(beam_size) done_beams = [] for t in range(self.seq_length): @@ -110,7 +113,7 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob final_beam = { 'seq': beam_seq[:, vix].clone(), 'logps': beam_seq_logprobs[:, vix].clone(), - 'p': beam_logprobs_sum[vix] + 'p': float(beam_logprobs_sum[vix].cpu().numpy()) } done_beams.append(final_beam) # don't continue beams from finished sequences diff --git a/scripts/convert_old.py b/scripts/convert_old.py new file mode 100644 index 00000000..5744f1c1 --- /dev/null +++ b/scripts/convert_old.py @@ -0,0 +1,55 @@ +import argparse +import h5py +import os +import numpy as np +import json + + +def main(params): + if not os.path.isdir(params['fc_output_dir']): + os.mkdir(params['fc_output_dir']) + if not os.path.isdir(params['att_output_dir']): + os.mkdir(params['att_output_dir']) + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + N = len(imgs) + + with h5py.File(os.path.join(params['fc_output_dir'], 'feats_fc.h5')) as file_fc,\ + h5py.File(os.path.join(params['att_output_dir'], 'feats_att.h5')) as file_att: + for i, img in enumerate(imgs): + npy_fc_path = os.path.join( + params['fc_input_dir'], + str(img['cocoid']) + '.npy') + npy_att_path = os.path.join( + params['att_input_dir'], + str(img['cocoid']) + '.npz') + + d_set_fc = file_fc.create_dataset( + str(img['cocoid']), data=np.load(npy_fc_path)) + d_set_att = file_att.create_dataset( + str(img['cocoid']), + data=np.load(npy_att_path)['feat']) + + if i % 1000 == 0: + print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N)) + file_fc.close() + file_att.close() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--fc_output_dir', default='data', help='output directory for fc') + parser.add_argument('--att_output_dir', default='data', help='output directory for att') + parser.add_argument('--fc_input_dir', default='data', help='input directory for numpy fc files') + parser.add_argument('--att_input_dir', default='data', help='input directory for numpy att files') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent=2)) + + main(params) diff --git a/scripts/prepro_feats.py b/scripts/prepro_feats.py index 6489e49f..6492a827 100644 --- a/scripts/prepro_feats.py +++ b/scripts/prepro_feats.py @@ -30,14 +30,11 @@ import os import json import argparse -from random import shuffle, seed -import string -# non-standard dependencies: import h5py -from six.moves import cPickle +from random import shuffle, seed + import numpy as np import torch -import torchvision.models as models from torch.autograd import Variable import skimage.io @@ -50,6 +47,7 @@ from misc.resnet_utils import myResnet import misc.resnet as resnet + def main(params): net = getattr(resnet, params['model'])() net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) @@ -70,25 +68,37 @@ def main(params): if not os.path.isdir(dir_att): os.mkdir(dir_att) - for i,img in enumerate(imgs): - # load the image - I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) - # handle grayscale input images - if len(I.shape) == 2: - I = I[:,:,np.newaxis] - I = np.concatenate((I,I,I), axis=2) - - I = I.astype('float32')/255.0 - I = torch.from_numpy(I.transpose([2,0,1])).cuda() - I = Variable(preprocess(I), volatile=True) - tmp_fc, tmp_att = my_resnet(I, params['att_size']) - # write to pkl - np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - - if i % 1000 == 0: - print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', params['output_dir']) + with h5py.File(os.path.join(dir_fc, 'feats_fc.h5')) as file_fc,\ + h5py.File(os.path.join(dir_att, 'feats_att.h5')) as file_att: + for i, img in enumerate(imgs): + # load the image + I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) + # handle grayscale input images + if len(I.shape) == 2: + I = I[:,:,np.newaxis] + I = np.concatenate((I,I,I), axis=2) + + I = I.astype('float32')/255.0 + I = torch.from_numpy(I.transpose([2,0,1])).cuda() + with torch.no_grad(): + I = Variable(preprocess(I)) + tmp_fc, tmp_att = my_resnet(I, params['att_size']) + + # write to hdf5 + d_set_fc = file_fc.create_dataset( + str(img['cocoid']), + (2048,), dtype="float") + d_set_att = file_att.create_dataset( + str(img['cocoid']), + (params['att_size'], params['att_size'], 2048), dtype="float") + + d_set_fc[...] = tmp_fc.cpu().float().numpy() + d_set_att[...] = tmp_att.cpu().float().numpy() + if i % 1000 == 0: + print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0 / N)) + file_fc.close() + file_att.close() + if __name__ == "__main__": @@ -96,7 +106,7 @@ def main(params): # input json parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') + parser.add_argument('--output_dir', default='data', help='output directory') # options parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') diff --git a/train.py b/train.py index 7a626977..83850094 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -114,7 +115,7 @@ def train(opt): loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() - train_loss = loss.data[0] + train_loss = loss.item() torch.cuda.synchronize() end = time.time() print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \