As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
- task, the environment terminates if the pole falls over too far.
+ task, rewards are +1 for every incremental timestep and the environment
+ terminates if the pole falls over too far or the cart moves more than 2.4
+ units away from center. This means that better performing scenarios will run
+ for a longer duration, accumulating a larger return.
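
For example, a single interaction with the environment follows the standard
gym API (the variable names here are illustrative only, not code from this
tutorial)::

    observation, reward, done, info = env.step(action)
    # reward is 1.0 on every step; done becomes True once the pole tips
    # too far or the cart leaves the +/-2.4 unit range, ending the episode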

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
- #   our environment
+ #   our environment. It essentially maps (state, action) pairs
+ #   to their (next_state, reward) result, with the state being the
+ #   screen difference image as described later on.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#   transitions observed recently. It also implements a ``.sample()``
#   method for selecting a random batch of transitions for training.
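#
# As a point of reference, a minimal sketch consistent with this description
# (the tutorial's full definitions follow later in the file; ``namedtuple``
# and ``random`` are assumed to be imported) looks like::
#
#     Transition = namedtuple('Transition',
#                             ('state', 'action', 'next_state', 'reward'))
#
#     class ReplayMemory(object):
#
#         def __init__(self, capacity):
#             self.capacity = capacity
#             self.memory = []
#             self.position = 0
#
#         def push(self, *args):
#             """Saves a transition."""
#             if len(self.memory) < self.capacity:
#                 self.memory.append(None)
#             self.memory[self.position] = Transition(*args)
#             self.position = (self.position + 1) % self.capacity
#
#         def sample(self, batch_size):
#             return random.sample(self.memory, batch_size)
#
#         def __len__(self):
#             return len(self.memory)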
@@ -197,22 +202,32 @@ def __len__(self):
# difference between the current and previous screen patches. It has two
# outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
- # network). In effect, the network is trying to predict the *quality* of
+ # network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

-     def __init__(self):
+     def __init__(self, h, w):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
-         self.head = nn.Linear(448, 2)

+         # Number of Linear input connections depends on output of conv2d layers
+         # and therefore the input image size, so compute it.
+         def conv2d_size_out(size, kernel_size=5, stride=2):
+             return (size - (kernel_size - 1) - 1) // stride + 1
+         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
+         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
+         linear_input_size = convw * convh * 32
+         self.head = nn.Linear(linear_input_size, 2)  # 448 or 512
+
+     # Called with either one element to determine next action, or a batch
+     # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
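
######################################################################
# As a sanity check on the ``conv2d_size_out`` arithmetic above (the 40x90
# input used here is only an assumed example; the real size comes from
# ``get_screen``)::
#
#     def conv2d_size_out(size, kernel_size=5, stride=2):
#         return (size - (kernel_size - 1) - 1) // stride + 1
#
#     h = conv2d_size_out(conv2d_size_out(conv2d_size_out(40)))  # 40 -> 18 -> 7 -> 2
#     w = conv2d_size_out(conv2d_size_out(conv2d_size_out(90)))  # 90 -> 43 -> 20 -> 8
#     print(h * w * 32)  # 512, matching the "448 or 512" comment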
@@ -234,23 +249,21 @@ def forward(self, x):
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

- # This is based on the code from gym.
- screen_width = 600

-
- def get_cart_location():
+ def get_cart_location(screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART

-
def get_screen():
-     screen = env.render(mode='rgb_array').transpose(
-         (2, 0, 1))  # transpose into torch order (CHW)
-     # Strip off the top and bottom of the screen
-     screen = screen[:, 160:320]
-     view_width = 320
-     cart_location = get_cart_location()
+     # The screen returned by gym is 400x600x3, but is sometimes larger,
+     # such as 800x1200x3. Transpose it into torch order (CHW).
+     screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+     # Cart is in the lower half, so strip off the top and bottom of the screen
+     _, screen_height, screen_width = screen.shape
+     screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
+     view_width = int(screen_width * 0.6)
+     cart_location = get_cart_location(screen_width)
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
@@ -305,8 +318,14 @@ def get_screen():
EPS_DECAY = 200
TARGET_UPDATE = 10

- policy_net = DQN().to(device)
- target_net = DQN().to(device)
+ # Get screen size so that we can initialize layers correctly based on shape
+ # returned from AI gym. Typical dimensions at this point are close to 3x40x90,
+ # which is the result of a clamped and down-scaled render buffer in get_screen().
+ init_screen = get_screen()
+ _, _, screen_height, screen_width = init_screen.shape
+
+ policy_net = DQN(screen_height, screen_width).to(device)
+ target_net = DQN(screen_height, screen_width).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

@@ -325,6 +344,9 @@ def select_action(state):
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
+             # t.max(1) will return the largest column value of each row.
+             # The second element of the max result is the index of where the max
+             # element was found, so we pick the action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
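
######################################################################
# A quick illustration of the ``max(1)`` indexing used above (the values
# below are made up)::
#
#     q = torch.tensor([[0.2, 0.7]])       # e.g. a policy_net(state) output
#     q.max(1)                # (values=tensor([0.7000]), indices=tensor([1]))
#     q.max(1)[1]             # tensor([1]) - index of the best action
#     q.max(1)[1].view(1, 1)  # tensor([[1]]), the shape select_action returns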
@@ -376,10 +398,12 @@ def optimize_model():
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
-     # detailed explanation).
+     # detailed explanation). This converts a batch-array of Transitions
+     # to a Transition of batch-arrays.
    batch = Transition(*zip(*transitions))
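    # For illustration only (hypothetical values): zip(*transitions) regroups
    # [Transition(s1, a1, s1_next, r1), Transition(s2, a2, s2_next, r2)] into
    # Transition(state=(s1, s2), action=(a1, a2),
    #            next_state=(s1_next, s2_next), reward=(r1, r2)).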

    # Compute a mask of non-final states and concatenate the batch elements
+     # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state
@@ -389,10 +413,15 @@ def optimize_model():
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
-     # columns of actions taken
+     # columns of actions taken. These are the actions which would've been taken
+     # for each batch state according to policy_net.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
+     # Expected values of actions for non_final_next_states are computed based
+     # on the "older" target_net; selecting their best reward with max(1)[0].
+     # This is merged based on the mask, such that we'll have either the expected
+     # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
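
######################################################################
# For reference, the "expected Q values" mentioned in the comment above are the
# standard one-step TD target. A sketch, assuming ``GAMMA`` is the discount
# factor defined with the other hyperparameters::
#
#     # r + GAMMA * max_a Q_target(s', a), which reduces to just r for final states
#     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
#
# This target is then compared against ``state_action_values`` with a
# Huber (smooth L1) loss.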
@@ -418,7 +447,8 @@ def optimize_model():
# fails), we restart the loop.
#
# Below, `num_episodes` is set small. You should download
- # the notebook and run lot more epsiodes.
+ # the notebook and run a lot more episodes, such as 300+, for meaningful
+ # duration improvements.
#

num_episodes = 50
@@ -454,7 +484,7 @@ def optimize_model():
            episode_durations.append(t + 1)
            plot_durations()
            break
-     # Update the target network
+     # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

@@ -463,3 +493,16 @@ def optimize_model():
env.close()
plt.ioff()
plt.show()
+
+ ######################################################################
+ # Here is the diagram that illustrates the overall resulting data flow.
+ #
+ # .. figure:: /_static/img/reinforcement_learning_diagram.jpg
+ #
+ # Actions are chosen either randomly or based on a policy, getting the next
+ # step sample from the gym environment. We record the results in the
+ # replay memory and also run an optimization step on every iteration.
+ # Optimization picks a random batch from the replay memory for training the
+ # new policy. The "older" target_net is also used in optimization to compute the
+ # expected Q values; it is updated occasionally to keep it current.
+ #
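
######################################################################
# In pseudocode, the loop described above looks roughly as follows. This is a
# condensed sketch of the training cells earlier in the tutorial, not a
# verbatim copy; ``get_screen_difference`` stands in for the screen-difference
# computation done inline in the episode loop::
#
#     for i_episode in range(num_episodes):
#         state = get_screen_difference()
#         for t in count():
#             action = select_action(state)          # eps-greedy on policy_net
#             _, reward, done, _ = env.step(action.item())
#             next_state = None if done else get_screen_difference()
#             memory.push(state, action, next_state, reward)
#             state = next_state
#             optimize_model()                        # random batch from memory
#             if done:
#                 break
#         if i_episode % TARGET_UPDATE == 0:          # refresh the "older" net
#             target_net.load_state_dict(policy_net.state_dict())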