Skip to content

Commit 7ad027f

Browse files
dmitriy-serdyukruotianluo
authored andcommitted
Pytorch 4 (#71)
* Update to pytorch4.1 * Refactor * Fix dataloader * Change to item
1 parent 0ff90bd commit 7ad027f

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ This is an image captioning codebase in PyTorch. If you are familiar with neural
88

99
## Requirements
1010
Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3)
11-
PyTorch 0.2 (along with torchvision)
11+
PyTorch 0.4.1 (along with torchvision)
1212

1313
You need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`.
1414

dataloaderraw.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,10 @@ def get_batch(self, split, batch_size=None):
108108
img = np.concatenate((img, img, img), axis=2)
109109

110110
img = img.astype('float32')/255.0
111-
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
112-
img = Variable(preprocess(img), volatile=True)
113-
tmp_fc, tmp_att = self.my_resnet(img)
111+
img = torch.from_numpy(img.transpose([2, 0, 1])).cuda()
112+
with torch.no_grad():
113+
img = Variable(preprocess(img))
114+
tmp_fc, tmp_att = self.my_resnet(img)
114115

115116
fc_batch[i] = tmp_fc.data.cpu().float().numpy()
116117
att_batch[i] = tmp_att.data.cpu().float().numpy()
@@ -136,4 +137,3 @@ def get_vocab_size(self):
136137

137138
def get_vocab(self):
138139
return self.ix_to_word
139-

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def train(opt):
114114
loss.backward()
115115
utils.clip_gradient(optimizer, opt.grad_clip)
116116
optimizer.step()
117-
train_loss = loss.data[0]
117+
train_loss = loss.item()
118118
torch.cuda.synchronize()
119119
end = time.time()
120120
print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \

0 commit comments

Comments
 (0)