import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
-from copy import deepcopy
from PIL import Image

import torch
@@ -187,7 +186,7 @@ def __len__(self):
#
# .. math::
#
-#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
+#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
#
# .. math::
#
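For reference (not part of the diff): the per-sample term \mathcal{L}(\delta) being averaged in the batch loss above is the Huber / smooth-L1 loss that `F.smooth_l1_loss` computes further down, applied to the temporal-difference error. A sketch of the standard definition:

\delta = Q(s, a) - \left(r + \gamma \max_a Q(s', a)\right)

\mathcal{L}(\delta) =
\begin{cases}
  \frac{1}{2}\delta^{2}    & \text{for } |\delta| \le 1, \\
  |\delta| - \frac{1}{2}   & \text{otherwise.}
\end{cases}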
@@ -237,7 +236,7 @@ def forward(self, x):
#

resize = T.Compose([T.ToPILImage(),
-                    T.Scale(40, interpolation=Image.CUBIC),
+                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

# This is based on the code from gym.
@@ -273,6 +272,7 @@ def get_screen():
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).type(Tensor)

+
env.reset()
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
@@ -311,13 +311,18 @@ def get_screen():
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
+TARGET_UPDATE = 10

-model = DQN()
+policy_net = DQN()
+target_net = DQN()
+target_net.load_state_dict(policy_net.state_dict())
+target_net.eval()

if use_cuda:
-    model.cuda()
+    policy_net.cuda()
+    target_net.cuda()

-optimizer = optim.RMSprop(model.parameters())
+optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

@@ -331,7 +336,7 @@ def select_action(state):
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
-        return model(
+        return policy_net(
            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
    else:
        return LongTensor([[random.randrange(2)]])
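As a quick illustration of the epsilon-greedy schedule used in `select_action` above, here is a small standalone sketch (plain Python, not part of the diff, with the constant values introduced in this commit) showing how the exploration threshold decays as `steps_done` grows:

import math

# Constants as set earlier in this diff.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200

def eps_threshold(steps_done):
    # Exponential decay from EPS_START toward EPS_END.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

for steps in (0, 200, 1000):
    print(steps, round(eps_threshold(steps), 3))
# prints 0.9 at step 0, 0.363 at step 200, 0.056 at step 1000:
# the agent explores ~90% of the time at first and ~5% later on.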
@@ -371,14 +376,14 @@ def plot_durations():
# all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
-# state.
-
-
-last_sync = 0
-
+# state. We also use a target network to compute :math:`V(s_{t+1})` for
+# added stability. The target network has its weights kept frozen most of
+# the time, but is updated with the policy network's weights every so often.
+# This is usually a set number of steps, but we shall use episodes for
+# simplicity.
+#

def optimize_model():
-    global last_sync
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
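The comment added above is the heart of this change: a second, periodically synchronised copy of the Q-network supplies the bootstrap targets. A minimal standalone sketch of that pattern (not part of the diff; `nn.Linear` is used as a stand-in for the tutorial's `DQN` class, and the actual training step is elided):

import torch.nn as nn

policy_net = nn.Linear(4, 2)      # stand-in for DQN(); updated on every optimisation step
target_net = nn.Linear(4, 2)      # frozen copy used only to compute V(s_{t+1})
target_net.load_state_dict(policy_net.state_dict())  # start the two networks in sync
target_net.eval()

TARGET_UPDATE = 10
for i_episode in range(100):
    ...  # run one episode, calling the optimisation routine on each step
    if i_episode % TARGET_UPDATE == 0:
        # Hard update: copy the current policy weights into the target network.
        target_net.load_state_dict(policy_net.state_dict())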
@@ -389,10 +394,6 @@ def optimize_model():
    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))
-
-    # We don't want to backprop through the expected action values and volatile
-    # will save us on temporarily changing the model parameters'
-    # requires_grad to False!
    non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                if s is not None]),
                                     volatile=True)
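To make the masking above concrete, here is a small standalone sketch (not part of the diff; modern tensor API, hypothetical 4-dimensional states) of how terminal transitions, stored with `next_state = None`, are excluded before evaluating the maximum next-state value:

import torch

# Hypothetical batch of next states; None marks a transition that ended the episode.
next_states = [torch.randn(1, 4), None, torch.randn(1, 4)]

non_final_mask = torch.tensor([s is not None for s in next_states])           # tensor([True, False, True])
non_final_next_states = torch.cat([s for s in next_states if s is not None])  # shape (2, 4)

# Terminal entries keep V(s') = 0; the others receive the network's estimate.
values = torch.zeros(len(next_states))
values[non_final_mask] = torch.randn(2)  # placeholder for target_net(non_final_next_states).max(1)[0]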
@@ -402,28 +403,27 @@ def optimize_model():

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken
-    state_action_values = model(state_batch).gather(1, action_batch)
+    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
-    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
-    # Now, we don't want to mess up the loss with a volatile flag, so let's
-    # clear it. After this, we'll just end up with a Variable that has
-    # requires_grad=False
-    next_state_values.volatile = False
+    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
+    # Undo volatility (which was used to prevent unnecessary gradients)
+    expected_state_action_values = Variable(expected_state_action_values.data)

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
-    for param in model.parameters():
+    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

+
######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
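The `.gather(1, action_batch)` call in the hunk above is what selects Q(s_t, a_t) for the actions that were actually taken; a tiny standalone illustration (not part of the diff, made-up numbers):

import torch

# Q-values for a hypothetical batch of 3 states with 2 actions each.
q_values = torch.tensor([[0.1, 0.9],
                         [0.4, 0.2],
                         [0.7, 0.3]])
actions = torch.tensor([[1], [0], [1]])      # column index of the action taken in each state

state_action_values = q_values.gather(1, actions)
print(state_action_values)                   # tensor([[0.9], [0.4], [0.3]])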
@@ -434,8 +434,9 @@ def optimize_model():
#
# Below, `num_episodes` is set small. You should download
# the notebook and run a lot more episodes.
+#

-num_episodes = 10
+num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
@@ -468,6 +469,9 @@ def optimize_model():
            episode_durations.append(t + 1)
            plot_durations()
            break
+    # Update the target network
+    if i_episode % TARGET_UPDATE == 0:
+        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render(close=True)