
Commit 1c8e29e

[pl] train_pl: fix wrong indention in evaluation.

1 parent 6c6da82 · commit 1c8e29e

1 file changed, +42 −42 lines


tools/train_pl.py

Lines changed: 42 additions & 42 deletions
@@ -158,48 +158,48 @@ def validation_step(self, data, batch_idx):
             loss = crit(model(fc_feats, att_feats,
                               labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
 
-            # forward the model to also get generated samples for each image
-            # Only leave one feature for each image, in case duplicate sample
-            tmp_eval_kwargs = eval_kwargs.copy()
-            tmp_eval_kwargs.update({'sample_n': 1})
-            seq, seq_logprobs = model(
-                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
-            seq = seq.data
-            entropy = - (F.softmax(seq_logprobs, dim=2) *
-                         seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
-            perplexity = - \
-                seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
-                    2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
-
-            # Print beam search
-            if beam_size > 1 and verbose_beam:
-                for i in range(fc_feats.shape[0]):
-                    print('\n'.join([utils.decode_sequence(model.vocab, _[
-                        'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
-                    print('--' * 10)
-            sents = utils.decode_sequence(model.vocab, seq)
-
-            for k, sent in enumerate(sents):
-                entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
-                         'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
-                if eval_kwargs.get('dump_path', 0) == 1:
-                    entry['file_name'] = data['infos'][k]['file_path']
-                predictions.append(entry)
-                if eval_kwargs.get('dump_images', 0) == 1:
-                    # dump the raw image to vis/ folder
-                    cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
-                        '" vis/imgs/img' + \
-                        str(len(predictions)) + '.jpg'  # bit gross
-                    print(cmd)
-                    os.system(cmd)
-
-                if verbose:
-                    print('image %s: %s' %
-                          (entry['image_id'], entry['caption']))
-
-            if sample_n > 1:
-                eval_utils.eval_split_n(model, n_predictions, [
-                    fc_feats, att_feats, att_masks, data], eval_kwargs)
+        # forward the model to also get generated samples for each image
+        # Only leave one feature for each image, in case duplicate sample
+        tmp_eval_kwargs = eval_kwargs.copy()
+        tmp_eval_kwargs.update({'sample_n': 1})
+        seq, seq_logprobs = model(
+            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
+        seq = seq.data
+        entropy = - (F.softmax(seq_logprobs, dim=2) *
+                     seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
+        perplexity = - \
+            seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
+                2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
+
+        # Print beam search
+        if beam_size > 1 and verbose_beam:
+            for i in range(fc_feats.shape[0]):
+                print('\n'.join([utils.decode_sequence(model.vocab, _[
+                    'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
+                print('--' * 10)
+        sents = utils.decode_sequence(model.vocab, seq)
+
+        for k, sent in enumerate(sents):
+            entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
+                     'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
+            if eval_kwargs.get('dump_path', 0) == 1:
+                entry['file_name'] = data['infos'][k]['file_path']
+            predictions.append(entry)
+            if eval_kwargs.get('dump_images', 0) == 1:
+                # dump the raw image to vis/ folder
+                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
+                    '" vis/imgs/img' + \
+                    str(len(predictions)) + '.jpg'  # bit gross
+                print(cmd)
+                os.system(cmd)
+
+            if verbose:
+                print('image %s: %s' %
+                      (entry['image_id'], entry['caption']))
+
+        if sample_n > 1:
+            eval_utils.eval_split_n(model, n_predictions, [
+                fc_feats, att_feats, att_masks, data], eval_kwargs)
 
         output = {
             'val_loss': loss,
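The 42 removed and 42 re-added lines are textually identical; only the indentation changes. Judging from the context line that stays put (the loss = crit(...) call), the sampling-and-evaluation block had been nested one level too deep, under the guard that computes the validation loss, so captions were only generated when ground-truth labels were available. A minimal sketch of the before/after control flow; the exact guard condition is an assumption inferred from the surrounding code, which is not shown in this hunk:

# Hypothetical reduction of validation_step to illustrate the fix.
def validation_step_before(data):
    predictions = []
    if data.get('labels') is not None:           # assumed guard around the loss
        loss = 0.0                               # stand-in for crit(model(...), ...)
        predictions.append('a sampled caption')  # evaluation wrongly inside the guard
    return predictions

def validation_step_after(data):
    predictions = []
    if data.get('labels') is not None:
        loss = 0.0
    predictions.append('a sampled caption')      # evaluation always runs now
    return predictions

print(validation_step_before({'labels': None}))  # []  -- no captions generated
print(validation_step_after({'labels': None}))   # ['a sampled caption']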

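For reference, the entropy and perplexity expressions in the hunk are length-normalized per-token statistics of the sampled sequences (the perplexity variable actually holds the mean negative log-probability, i.e. the log of the perplexity). A standalone sketch of the same computation on toy tensors; shapes and values here are illustrative assumptions, not taken from the training data:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shapes: batch of 2 sequences, max length 3, vocab of 5.
# seq holds sampled token ids (0 is padding); seq_logprobs holds
# per-step log-probabilities over the vocabulary.
seq = torch.tensor([[3, 1, 0],
                    [2, 4, 1]])
seq_logprobs = torch.log_softmax(torch.randn(2, 3, 5), dim=2)

# Non-pad token count per sequence, +1 as in the diff's normalizer.
lengths = (seq > 0).to(seq_logprobs).sum(1) + 1

# Mean entropy of the per-step sampling distributions
# (softmax of log-probs recovers the probabilities).
entropy = -(F.softmax(seq_logprobs, dim=2) *
            seq_logprobs).sum(2).sum(1) / lengths

# Mean negative log-probability of the tokens actually sampled.
perplexity = -seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / lengths

print(entropy, perplexity)  # one scalar per image for each statistic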
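Unrelated to this commit but visible in the hunk: the image dump shells out to cp via os.system, which the inline comment already calls "bit gross". A hedged alternative using only the standard library, should that code ever be cleaned up; the function and argument names below are hypothetical, not from the repository:

import os
import shutil

def dump_image(image_root, file_path, index, out_dir=os.path.join('vis', 'imgs')):
    """Copy one raw image into the vis folder without spawning a shell.

    Would stand in for the os.system('cp ...') call in the prediction
    loop above; all names here are illustrative.
    """
    src = os.path.join(image_root, file_path)
    dst = os.path.join(out_dir, 'img%d.jpg' % index)
    os.makedirs(out_dir, exist_ok=True)
    shutil.copyfile(src, dst)

# e.g. dump_image(eval_kwargs['image_root'],
#                 data['infos'][k]['file_path'], len(predictions))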