2
2
from scipy .sparse import csr_matrix
3
3
from scipy .sparse .csgraph import dijkstra
4
4
5
+
5
6
class gridworld :
6
7
"""A class for making gridworlds"""
7
8
def __init__ (self , image , targetx , targety ):
@@ -25,6 +26,8 @@ def __init__(self, image, targetx, targety):
25
26
26
27
27
28
def set_vals (self ):
29
+ # Setup function to initialize all necessary
30
+ # data
28
31
row_obs , col_obs = np .where (self .image == 0 )
29
32
row_free , col_free = np .where (self .image != 0 )
30
33
self .obstacles = [row_obs , col_obs ]
@@ -125,24 +128,28 @@ def set_vals(self):
125
128
126
129
127
130
def get_graph (self ):
131
+ # Returns graph
128
132
G = self .G
129
133
W = self .W [self .W != 0 ]
130
134
return G , W
131
135
132
136
133
137
def get_graph_inv (self ):
138
+ # Returns transpose of graph
134
139
G = self .G .T
135
140
W = self .W .T
136
141
return G , W
137
142
138
143
139
144
def val_2_image (self , val ):
145
+ # Zeros for obstacles, val for free space
140
146
im = np .zeros ((self .n_row , self .n_col ))
141
147
im [self .freespace [0 ], self .freespace [1 ]] = val
142
148
return im
143
149
144
150
145
151
def get_value_prior (self ):
152
+ # Returns value prior for gridworld
146
153
s_map_col , s_map_row = np .meshgrid (np .arange (0 ,self .n_col ),
147
154
np .arange (0 , self .n_row ))
148
155
im = np .sqrt (np .square (s_map_col - self .targety )
@@ -151,30 +158,37 @@ def get_value_prior(self):
151
158
152
159
153
160
def get_reward_prior (self ):
161
+ # Returns reward prior for gridworld
154
162
im = - 1 * np .ones ((self .n_row , self .n_col ))
155
163
im [self .targetx , self .targety ] = 10
156
164
return im
157
165
158
166
159
167
def t_get_reward_prior (self ):
168
+ # Returns reward prior as needed for
169
+ # dataset generation
160
170
im = np .zeros ((self .n_row , self .n_col ))
161
171
im [self .targetx , self .targety ] = 10
162
172
return im
163
173
164
174
165
175
def get_state_image (self , row , col ):
176
+ # Zeros everywhere except [row,col]
166
177
im = np .zeros ((self .n_row , self .n_col ))
167
178
im [row , col ] = 1
168
179
return im
169
180
170
181
171
182
def map_ind_to_state (self , row , col ):
183
+ # Takes [row, col] and maps to a state
172
184
rw = np .where (self .state_map_row == row )
173
185
cl = np .where (self .state_map_col == col )
174
186
return np .intersect1d (rw , cl )[0 ]
175
187
176
188
177
189
def get_coords (self , states ):
190
+ # Given a state or states, returns
191
+ # [row,col] pairs for the state(s)
178
192
non_obstacles = np .ravel_multi_index (
179
193
[self .freespace [0 ], self .freespace [1 ]],
180
194
(self .n_row ,self .n_col ), order = 'F' )
@@ -186,6 +200,7 @@ def get_coords(self, states):
186
200
187
201
188
202
def rand_choose (self , in_vec ):
203
+ # Samples
189
204
if len (in_vec .shape ) > 1 :
190
205
if in_vec .shape [1 ] == 1 :
191
206
in_vec = in_vec .T
@@ -197,6 +212,8 @@ def rand_choose(self, in_vec):
197
212
198
213
199
214
def next_state_prob (self , s , a ):
215
+ # Gets next state probability for
216
+ # a given action (a)
200
217
if hasattr (a , "__iter__" ):
201
218
p = np .squeeze (self .P [s , :, a ])
202
219
else :
@@ -205,16 +222,22 @@ def next_state_prob(self, s, a):
205
222
206
223
207
224
def sample_next_state (self , s , a ):
225
+ # Gets the next state given the
226
+ # current state (s) and an
227
+ # action (a)
208
228
vec = self .next_state_prob (s , a )
209
229
result = self .rand_choose (vec )
210
230
return result
211
231
212
232
213
233
def get_size (self ):
234
+ # Returns domain size
214
235
return self .n_row , self .n_col
215
236
216
237
217
238
def north (self , row , col ):
239
+ # Returns new [row,col]
240
+ # if we take the action
218
241
new_row = np .max ([row - 1 , 0 ])
219
242
new_col = col
220
243
if self .image [new_row , new_col ] == 0 :
@@ -224,6 +247,8 @@ def north(self, row, col):
224
247
225
248
226
249
def northeast (self , row , col ):
250
+ # Returns new [row,col]
251
+ # if we take the action
227
252
new_row = np .max ([row - 1 , 0 ])
228
253
new_col = np .min ([col + 1 , self .n_col - 1 ])
229
254
if self .image [new_row , new_col ] == 0 :
@@ -233,6 +258,8 @@ def northeast(self, row, col):
233
258
234
259
235
260
def northwest (self , row , col ):
261
+ # Returns new [row,col]
262
+ # if we take the action
236
263
new_row = np .max ([row - 1 , 0 ])
237
264
new_col = np .max ([col - 1 , 0 ])
238
265
if self .image [new_row , new_col ] == 0 :
@@ -242,6 +269,8 @@ def northwest(self, row, col):
242
269
243
270
244
271
def south (self , row , col ):
272
+ # Returns new [row,col]
273
+ # if we take the action
245
274
new_row = np .min ([row + 1 , self .n_row - 1 ])
246
275
new_col = col
247
276
if self .image [new_row , new_col ] == 0 :
@@ -251,6 +280,8 @@ def south(self, row, col):
251
280
252
281
253
282
def southeast (self , row , col ):
283
+ # Returns new [row,col]
284
+ # if we take the action
254
285
new_row = np .min ([row + 1 , self .n_row - 1 ])
255
286
new_col = np .min ([col + 1 , self .n_col - 1 ])
256
287
if self .image [new_row , new_col ] == 0 :
@@ -260,6 +291,8 @@ def southeast(self, row, col):
260
291
261
292
262
293
def southwest (self , row , col ):
294
+ # Returns new [row,col]
295
+ # if we take the action
263
296
new_row = np .min ([row + 1 , self .n_row - 1 ])
264
297
new_col = np .max ([col - 1 , 0 ])
265
298
if self .image [new_row , new_col ] == 0 :
@@ -269,6 +302,8 @@ def southwest(self, row, col):
269
302
270
303
271
304
def east (self , row , col ):
305
+ # Returns new [row,col]
306
+ # if we take the action
272
307
new_row = row
273
308
new_col = np .min ([col + 1 , self .n_col - 1 ])
274
309
if self .image [new_row , new_col ] == 0 :
@@ -278,6 +313,8 @@ def east(self, row, col):
278
313
279
314
280
315
def west (self , row , col ):
316
+ # Returns new [row,col]
317
+ # if we take the action
281
318
new_row = row
282
319
new_col = np .max ([col - 1 , 0 ])
283
320
if self .image [new_row , new_col ] == 0 :
@@ -307,6 +344,9 @@ def neighbors(self, row, col):
307
344
308
345
309
346
def trace_path (pred , source , target ):
347
+ # traces back shortest path from
348
+ # source to target given pred
349
+ # (a predicessor list)
310
350
max_len = 1000
311
351
path = np .zeros ((max_len , 1 ))
312
352
i = max_len - 1
@@ -325,6 +365,8 @@ def trace_path(pred, source, target):
325
365
326
366
327
367
def sample_trajectory (M , n_states ):
368
+ # Samples trajectories from random nodes
369
+ # in our domain (M)
328
370
G , W = M .get_graph_inv ()
329
371
N = G .shape [0 ]
330
372
if N >= n_states :
0 commit comments