Skip to content

Commit e69ccc9

Browse files
author
Kent Sommer
committed
Code comments and cleanup
1 parent 583f192 commit e69ccc9

File tree

12 files changed

+103
-68
lines changed

12 files changed

+103
-68
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
7171
## Results
7272
Gridworld | Sample One | Sample Two
7373
-- | --- | ---
74-
8x8 | <img src="results/8x8_2.png" width="450"> | <img src="results/8x8_3.png" width="450">
74+
8x8 | <img src="results/8x8_1.png" width="450"> | <img src="results/8x8_2.png" width="450">
7575
16x16 | <img src="results/16x16_1.png" width="450"> | <img src="results/16x16_2.png" width="450">
7676

7777
## Datasets

dataset/dataset.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
1+
import numpy as np
2+
13
import torch
24
import torch.utils.data as data
3-
import numpy as np
45

56

67
class GridworldData(data.Dataset):
7-
def __init__(self, file, imsize, train=True, transform=None, target_transform=None):
8+
def __init__(self, file, imsize, train=True,
9+
transform=None, target_transform=None):
810
assert file.endswith('.npz') # Must be .npz format
911
self.file = file
1012
self.imsize = imsize
1113
self.transform = transform
1214
self.target_transform = target_transform
1315
self.train = train # training set or test set
1416

15-
self.images, self.S1, self.S2, self.labels = self._process(file, self.train)
17+
self.images, self.S1, self.S2, self.labels = \
18+
self._process(file, self.train)
1619

1720
def __getitem__(self, index):
1821
img = self.images[index]
1922
s1 = self.S1[index]
2023
s2 = self.S2[index]
2124
label = self.labels[index]
22-
25+
# Apply transform if we have one
2326
if self.transform is not None:
2427
img = self.transform(img)
2528
else: # Internal default transform: Just to Tensor
2629
img = torch.from_numpy(img)
27-
30+
# Apply target transform if we have one
2831
if self.target_transform is not None:
2932
label = self.target_transform(label)
30-
31-
# Ensure labels in naive float type
32-
# DataLoader has bug with np.int/float type in default_collate()
3333
return img, int(s1), int(s2), int(label)
3434

35+
3536
def __len__(self):
3637
return self.images.shape[0]
3738

39+
3840
def _process(self, file, train):
3941
"""Data format: A list, [train data, test data]
4042
Each data sample: label, S1, S2, Images, in this order.
@@ -50,15 +52,14 @@ def _process(self, file, train):
5052
S1 = f['arr_5']
5153
S2 = f['arr_6']
5254
labels = f['arr_7']
53-
55+
# Set proper datatypes
5456
images = images.astype(np.float32)
5557
S1 = S1.astype(int) # (S1, S2) location are integers
5658
S2 = S2.astype(int)
5759
labels = labels.astype(int) # labels are integers
58-
60+
# Print number of samples
5961
if train:
6062
print("Number of Train Samples: {0}".format(images.shape[0]))
6163
else:
6264
print("Number of Test Samples: {0}".format(images.shape[0]))
63-
6465
return images, S1, S2, labels

dataset/make_training_data.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,24 @@
99
sys.path.remove('.')
1010

1111
def extract_action(traj):
12+
# Given a trajectory, outputs a 1D vector of
13+
# actions corresponding to the trajectory.
1214
n_actions = 8
13-
action_vecs = np.asarray([[-1., 0.],[1.,0.],[0.,1.],[0.,-1.],[-1.,1.],[-1.,-1.],[1.,1.],[1.,-1.]])
15+
action_vecs = np.asarray([[-1., 0.],[1.,0.],[0.,1.],[0.,-1.],[-1.,1.],
16+
[-1.,-1.],[1.,1.],[1.,-1.]])
1417
action_vecs[4:] = 1/np.sqrt(2) * action_vecs[4:]
1518
action_vecs = action_vecs.T
1619
state_diff = np.diff(traj, axis=0)
17-
norm_state_diff = state_diff * np.tile(1/np.sqrt(np.sum(np.square(state_diff), axis=1)), (2, 1)).T
20+
norm_state_diff = state_diff * np.tile(1/np.sqrt(np.sum(np.square(
21+
state_diff), axis=1)), (2, 1)).T
1822
prj_state_diff = np.dot(norm_state_diff, action_vecs)
1923
actions_one_hot = np.abs(prj_state_diff -1)<0.00001
2024
actions = np.dot(actions_one_hot, np.arange(n_actions).T)
2125
return actions
2226

2327

24-
def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_size):
28+
def make_data(dom_size, n_domains, max_obs,
29+
max_obs_size, n_traj, state_batch_size):
2530

2631
X_l = []
2732
S1_l = []
@@ -43,7 +48,6 @@ def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_si
4348
continue
4449
# Get final map
4550
im = obs.get_final()
46-
4751
# Generate gridworld from obstacle map
4852
G = gridworld(im, goal[0], goal[1])
4953
# Get value prior
@@ -59,7 +63,8 @@ def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_si
5963
image = 1 - im
6064
# Resize domain and goal images and concate
6165
image_data = np.resize(image, (1,1,dom_size[0],dom_size[1]))
62-
value_data = np.resize(value_prior, (1,1,dom_size[0],dom_size[1]))
66+
value_data = np.resize(value_prior, (1,1,dom_size[0],
67+
dom_size[1]))
6368
iv_mixed = np.concatenate((image_data, value_data), axis=1)
6469
X_current = np.tile(iv_mixed, (ns, 1, 1, 1))
6570
# Resize states
@@ -86,18 +91,19 @@ def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_si
8691

8792
def main(dom_size=[8,8], n_domains=15000, max_obs=30, max_obs_size=None,
8893
n_traj=7, state_batch_size=1):
89-
94+
# Get path to save dataset
9095
save_path = "dataset/gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
91-
96+
# Get training data
9297
print("Now making training data...")
93-
X_out_tr, S1_out_tr, S2_out_tr, Labels_out_tr = make_data(dom_size, n_domains, max_obs,
94-
max_obs_size, n_traj, state_batch_size)
98+
X_out_tr, S1_out_tr, S2_out_tr, Labels_out_tr = make_data(
99+
dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_size)
100+
# Get testing data
95101
print("\nNow making testing data...")
96-
X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts = make_data(dom_size, n_domains/6,
97-
max_obs, max_obs_size, n_traj, state_batch_size)
98-
99-
np.savez_compressed(save_path, X_out_tr, S1_out_tr, S2_out_tr, Labels_out_tr,
100-
X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts)
102+
X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts = make_data(
103+
dom_size, n_domains/6, max_obs, max_obs_size, n_traj, state_batch_size)
104+
# Save dataset
105+
np.savez_compressed(save_path, X_out_tr, S1_out_tr, S2_out_tr,
106+
Labels_out_tr, X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts)
101107

102108

103109
if __name__ == '__main__':

domains/gridworld.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from scipy.sparse import csr_matrix
33
from scipy.sparse.csgraph import dijkstra
44

5+
56
class gridworld:
67
"""A class for making gridworlds"""
78
def __init__(self, image, targetx, targety):
@@ -25,6 +26,8 @@ def __init__(self, image, targetx, targety):
2526

2627

2728
def set_vals(self):
29+
# Setup function to initialize all necessary
30+
# data
2831
row_obs, col_obs = np.where(self.image == 0)
2932
row_free, col_free = np.where(self.image != 0)
3033
self.obstacles = [row_obs, col_obs]
@@ -125,24 +128,28 @@ def set_vals(self):
125128

126129

127130
def get_graph(self):
131+
# Returns graph
128132
G = self.G
129133
W = self.W[self.W != 0]
130134
return G, W
131135

132136

133137
def get_graph_inv(self):
138+
# Returns transpose of graph
134139
G = self.G.T
135140
W = self.W.T
136141
return G, W
137142

138143

139144
def val_2_image(self, val):
145+
# Zeros for obstacles, val for free space
140146
im = np.zeros((self.n_row, self.n_col))
141147
im[self.freespace[0], self.freespace[1]] = val
142148
return im
143149

144150

145151
def get_value_prior(self):
152+
# Returns value prior for gridworld
146153
s_map_col, s_map_row = np.meshgrid(np.arange(0,self.n_col),
147154
np.arange(0, self.n_row))
148155
im = np.sqrt(np.square(s_map_col - self.targety)
@@ -151,30 +158,37 @@ def get_value_prior(self):
151158

152159

153160
def get_reward_prior(self):
161+
# Returns reward prior for gridworld
154162
im = -1 * np.ones((self.n_row, self.n_col))
155163
im[self.targetx, self.targety] = 10
156164
return im
157165

158166

159167
def t_get_reward_prior(self):
168+
# Returns reward prior as needed for
169+
# dataset generation
160170
im = np.zeros((self.n_row, self.n_col))
161171
im[self.targetx, self.targety] = 10
162172
return im
163173

164174

165175
def get_state_image(self, row, col):
176+
# Zeros everywhere except [row,col]
166177
im = np.zeros((self.n_row, self.n_col))
167178
im[row, col] = 1
168179
return im
169180

170181

171182
def map_ind_to_state(self, row, col):
183+
# Takes [row, col] and maps to a state
172184
rw = np.where(self.state_map_row == row)
173185
cl = np.where(self.state_map_col == col)
174186
return np.intersect1d(rw, cl)[0]
175187

176188

177189
def get_coords(self, states):
190+
# Given a state or states, returns
191+
# [row,col] pairs for the state(s)
178192
non_obstacles = np.ravel_multi_index(
179193
[self.freespace[0], self.freespace[1]],
180194
(self.n_row,self.n_col), order='F')
@@ -186,6 +200,7 @@ def get_coords(self, states):
186200

187201

188202
def rand_choose(self, in_vec):
203+
# Samples
189204
if len(in_vec.shape) > 1:
190205
if in_vec.shape[1] == 1:
191206
in_vec = in_vec.T
@@ -197,6 +212,8 @@ def rand_choose(self, in_vec):
197212

198213

199214
def next_state_prob(self, s, a):
215+
# Gets next state probability for
216+
# a given action (a)
200217
if hasattr(a, "__iter__"):
201218
p = np.squeeze(self.P[s, :, a])
202219
else:
@@ -205,16 +222,22 @@ def next_state_prob(self, s, a):
205222

206223

207224
def sample_next_state(self, s, a):
225+
# Gets the next state given the
226+
# current state (s) and an
227+
# action (a)
208228
vec = self.next_state_prob(s, a)
209229
result = self.rand_choose(vec)
210230
return result
211231

212232

213233
def get_size(self):
234+
# Returns domain size
214235
return self.n_row, self.n_col
215236

216237

217238
def north(self, row, col):
239+
# Returns new [row,col]
240+
# if we take the action
218241
new_row = np.max([row-1, 0])
219242
new_col = col
220243
if self.image[new_row, new_col] == 0:
@@ -224,6 +247,8 @@ def north(self, row, col):
224247

225248

226249
def northeast(self, row, col):
250+
# Returns new [row,col]
251+
# if we take the action
227252
new_row = np.max([row - 1, 0])
228253
new_col = np.min([col + 1, self.n_col - 1])
229254
if self.image[new_row, new_col] == 0:
@@ -233,6 +258,8 @@ def northeast(self, row, col):
233258

234259

235260
def northwest(self, row, col):
261+
# Returns new [row,col]
262+
# if we take the action
236263
new_row = np.max([row - 1, 0])
237264
new_col = np.max([col - 1, 0])
238265
if self.image[new_row, new_col] == 0:
@@ -242,6 +269,8 @@ def northwest(self, row, col):
242269

243270

244271
def south(self, row, col):
272+
# Returns new [row,col]
273+
# if we take the action
245274
new_row = np.min([row + 1, self.n_row - 1])
246275
new_col = col
247276
if self.image[new_row, new_col] == 0:
@@ -251,6 +280,8 @@ def south(self, row, col):
251280

252281

253282
def southeast(self, row, col):
283+
# Returns new [row,col]
284+
# if we take the action
254285
new_row = np.min([row + 1, self.n_row - 1])
255286
new_col = np.min([col + 1, self.n_col - 1])
256287
if self.image[new_row, new_col] == 0:
@@ -260,6 +291,8 @@ def southeast(self, row, col):
260291

261292

262293
def southwest(self, row, col):
294+
# Returns new [row,col]
295+
# if we take the action
263296
new_row = np.min([row + 1, self.n_row - 1])
264297
new_col = np.max([col - 1, 0])
265298
if self.image[new_row, new_col] == 0:
@@ -269,6 +302,8 @@ def southwest(self, row, col):
269302

270303

271304
def east(self, row, col):
305+
# Returns new [row,col]
306+
# if we take the action
272307
new_row = row
273308
new_col = np.min([col + 1, self.n_col - 1])
274309
if self.image[new_row, new_col] == 0:
@@ -278,6 +313,8 @@ def east(self, row, col):
278313

279314

280315
def west(self, row, col):
316+
# Returns new [row,col]
317+
# if we take the action
281318
new_row = row
282319
new_col = np.max([col - 1, 0])
283320
if self.image[new_row, new_col] == 0:
@@ -307,6 +344,9 @@ def neighbors(self, row, col):
307344

308345

309346
def trace_path(pred, source, target):
347+
# traces back shortest path from
348+
# source to target given pred
349+
# (a predicessor list)
310350
max_len = 1000
311351
path = np.zeros((max_len, 1))
312352
i = max_len - 1
@@ -325,6 +365,8 @@ def trace_path(pred, source, target):
325365

326366

327367
def sample_trajectory(M, n_states):
368+
# Samples trajectories from random nodes
369+
# in our domain (M)
328370
G, W = M.get_graph_inv()
329371
N = G.shape[0]
330372
if N >= n_states:

0 commit comments

Comments
 (0)