Skip to content

Commit 15fefd5

Browse files
author
Kent Sommer
committed
Patches for current PyTorch Release, fixes #3, #4, and #5
1 parent 32d2037 commit 15fefd5

File tree

7 files changed

+128
-14
lines changed

7 files changed

+128
-14
lines changed

.gitignore

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
.pytest_cache/
49+
50+
# Translations
51+
*.mo
52+
*.pot
53+
54+
# Django stuff:
55+
*.log
56+
local_settings.py
57+
db.sqlite3
58+
59+
# Flask stuff:
60+
instance/
61+
.webassets-cache
62+
63+
# Scrapy stuff:
64+
.scrapy
65+
66+
# Sphinx documentation
67+
docs/_build/
68+
69+
# PyBuilder
70+
target/
71+
72+
# Jupyter Notebook
73+
.ipynb_checkpoints
74+
75+
# pyenv
76+
.python-version
77+
78+
# celery beat schedule file
79+
celerybeat-schedule
80+
81+
# SageMath parsed files
82+
*.sage.py
83+
84+
# Environments
85+
.env
86+
.venv
87+
env/
88+
venv/
89+
ENV/
90+
env.bak/
91+
venv.bak/
92+
93+
# Spyder project settings
94+
.spyderproject
95+
.spyproject
96+
97+
# Rope project settings
98+
.ropeproject
99+
100+
# mkdocs documentation
101+
/site
102+
103+
# mypy
104+
.mypy_cache/
105+
106+
# npz
107+
*.npz
108+
109+
# pth
110+
*.pth

dataset/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _process(self, file, train):
4141
"""Data format: A list, [train data, test data]
4242
Each data sample: label, S1, S2, Images, in this order.
4343
"""
44-
with np.load(file) as f:
44+
with np.load(file, mmap_mode='r') as f:
4545
if train:
4646
images = f['arr_0']
4747
S1 = f['arr_1']

dataset/make_training_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def make_data(dom_size, n_domains, max_obs,
8989
return X_f, S1_f, S2_f, Labels_f
9090

9191

92-
def main(dom_size=[8,8], n_domains=15000, max_obs=30, max_obs_size=None,
92+
def main(dom_size=[28,28], n_domains=5000, max_obs=50, max_obs_size=2,
9393
n_traj=7, state_batch_size=1):
9494
# Get path to save dataset
9595
save_path = "dataset/gridworld_{0}x{1}".format(dom_size[0], dom_size[1])

download_weights_and_datasets.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
cd trained
2-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_8x8.pth'
3-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_16x16.pth'
4-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_28x28.pth'
2+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_8x8.pth'
3+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_16x16.pth'
4+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_28x28.pth'
55
cd ../dataset
6-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_8x8.npz'
7-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_16x16.npz'
8-
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_28x28.npz'
6+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_8x8.npz'
7+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_16x16.npz'
8+
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_28x28.npz'
99
cd ..

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, config):
3232
out_features=8,
3333
bias=False)
3434
self.w = Parameter(torch.zeros(config.l_q,1,3,3), requires_grad=True)
35-
self.sm = nn.Softmax()
35+
self.sm = nn.Softmax(dim=1)
3636

3737

3838
def forward(self, X, S1, S2, config):

test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ def main(config, n_domains=100, max_obs=30,
2222
correct, total = 0.0, 0.0
2323
# Automatic swith of GPU mode if available
2424
use_GPU = torch.cuda.is_available()
25-
vin = torch.load(config.weights)
25+
# Instantiate a VIN model
26+
vin = VIN(config)
27+
# Load model parameters
28+
vin.load_state_dict(torch.load(config.weights))
29+
# Use GPU if available
2630
if use_GPU:
27-
vin = vin.cuda()
31+
vin = vin.cuda()
2832

2933
for dom in range(n_domains):
3034
# Randomly select goal position

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test(net, testloader, config):
7777
# Unwrap autograd.Variable to Tensor
7878
predicted = predicted.data
7979
# Compute test accuracy
80-
correct += (predicted == labels).sum()
80+
correct += (torch.eq(torch.squeeze(predicted), labels)).sum()
8181
total += labels.size()[0]
8282
print('Test Accuracy: {:.2f}%'.format(100*(correct/total)))
8383

@@ -147,5 +147,5 @@ def test(net, testloader, config):
147147
train(net, trainloader, config, criterion, optimizer, use_GPU)
148148
# Test accuracy
149149
test(net, testloader, config)
150-
# Save the trained model
151-
torch.save(net, save_path)
150+
# Save the trained model parameters
151+
torch.save(net.state_dict(), save_path)

0 commit comments

Comments
 (0)