This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
- on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+ on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
**Task**
The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
- However, neural networks can solve the task purely by looking at the
- scene, so we'll use a patch of the screen centered on the cart as an
- input. Because of this, our results aren't directly comparable to the
- ones from the official leaderboard - our task is much harder.
- Unfortunately this does slow down the training, because we have to
- render all the frames.
+ We take these 4 inputs without any scaling and pass them through a
+ small fully-connected network with 2 outputs, one for each action.
+ The network is trained to predict the expected value for each action,
+ given the input state. The action with the highest expected value is
+ then chosen.

- Strictly speaking, we will present the state as the difference between
- the current screen patch and the previous one. This will allow the agent
- to take the velocity of the pole into account from one image.
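As a quick check of the sizes mentioned above (a minimal sketch assuming gym 0.25+ with ``CartPole-v1``; ``check_env`` is just a throwaway name), the observation and action spaces can be inspected directly:

.. code-block:: python

   import gym

   check_env = gym.make('CartPole-v1')
   print(check_env.observation_space.shape)  # (4,): cart position, cart velocity, pole angle, pole angular velocity
   print(check_env.action_space.n)           # 2: push the cart to the left or to the right
   check_env.close()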
**Packages**
First, let's import the needed packages. We need
`gym <https://github.com/openai/gym>`__ for the environment.
+ Install by using `pip`. If you are running this in Google Colab, run:

.. code-block:: bash
- neural networks (``torch.nn``)
- optimization (``torch.optim``)
- automatic differentiation (``torch.autograd``)
- - utilities for vision tasks (``torchvision`` - `a separate
-   package <https://github.com/pytorch/vision>`__).

"""
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
- from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
- import torchvision.transforms as T

-
- if gym.__version__ < '0.26':
-     env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
+ if gym.__version__[:4] == '0.26':
+     env = gym.make('CartPole-v1')
+ elif gym.__version__[:4] == '0.25':
+     env = gym.make('CartPole-v1', new_step_api=True)
else:
-     env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped
+     raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
@@ -152,9 +146,11 @@ def __len__(self):
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
- # that ensures the sum converges. It makes rewards from the uncertain far
- # future less important for our agent than the ones in the near future
- # that it can be fairly confident about.
+ # that ensures the sum converges. A lower :math:`\gamma` makes
+ # rewards from the uncertain far future less important for our agent
+ # than the ones in the near future that it can be fairly confident
+ # about. It also encourages agents to collect reward closer in time
+ # rather than equivalent rewards that lie further in the future.
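As a small illustration of the effect of the discount (``discounted_return`` is a hypothetical helper, using the ``GAMMA = 0.99`` value this diff sets below): with a constant reward of 1 per step, near-term rewards count almost in full, while the sum of all future rewards is capped at :math:`1 / (1 - \gamma)`.

.. code-block:: python

   GAMMA = 0.99  # same value the updated tutorial uses below

   def discounted_return(rewards, gamma=GAMMA):
       # R_{t0} = sum_t gamma**t * r_t for a finite reward sequence
       return sum(gamma ** t * r for t, r in enumerate(rewards))

   print(discounted_return([1.0] * 10))    # ~9.56: the first few rewards count almost fully
   print(discounted_return([1.0] * 500))   # ~99.3, approaching the 1 / (1 - gamma) = 100 limit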
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
@@ -177,7 +173,7 @@ def __len__(self):
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
- # .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a))
+ # .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
#
# To minimise this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
@@ -211,86 +207,18 @@ def __len__(self):
class DQN(nn.Module):

-     def __init__(self, h, w, outputs):
+     def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
-         self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
-         self.bn1 = nn.BatchNorm2d(16)
-         self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-         self.bn2 = nn.BatchNorm2d(32)
-         self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-         self.bn3 = nn.BatchNorm2d(32)
-
-         # Number of Linear input connections depends on output of conv2d layers
-         # and therefore the input image size, so compute it.
-         def conv2d_size_out(size, kernel_size=5, stride=2):
-             return (size - (kernel_size - 1) - 1) // stride + 1
-         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-         linear_input_size = convw * convh * 32
-         self.head = nn.Linear(linear_input_size, outputs)
+         self.layer1 = nn.Linear(n_observations, 128)
+         self.layer2 = nn.Linear(128, 128)
+         self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
-         x = x.to(device)
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         return self.head(x.view(x.size(0), -1))
-
-
- ######################################################################
- # Input extraction
- # ^^^^^^^^^^^^^^^^
- #
- # The code below are utilities for extracting and processing rendered
- # images from the environment. It uses the ``torchvision`` package, which
- # makes it easy to compose image transforms. Once you run the cell it will
- # display an example patch that it extracted.
- #
-
- resize = T.Compose([T.ToPILImage(),
-                     T.Resize(40, interpolation=Image.CUBIC),
-                     T.ToTensor()])
-
-
- def get_cart_location(screen_width):
-     world_width = env.x_threshold * 2
-     scale = screen_width / world_width
-     return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
-
- def get_screen():
-     # Returned screen requested by gym is 400x600x3, but is sometimes larger
-     # such as 800x1200x3. Transpose it into torch order (CHW).
-     screen = env.render().transpose((2, 0, 1))
-     # Cart is in the lower half, so strip off the top and bottom of the screen
-     _, screen_height, screen_width = screen.shape
-     screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
-     view_width = int(screen_width * 0.6)
-     cart_location = get_cart_location(screen_width)
-     if cart_location < view_width // 2:
-         slice_range = slice(view_width)
-     elif cart_location > (screen_width - view_width // 2):
-         slice_range = slice(-view_width, None)
-     else:
-         slice_range = slice(cart_location - view_width // 2,
-                             cart_location + view_width // 2)
-     # Strip off the edges, so that we have a square image centered on a cart
-     screen = screen[:, :, slice_range]
-     # Convert to float, rescale, convert to torch tensor
-     # (this doesn't require a copy)
-     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
-     screen = torch.from_numpy(screen)
-     # Resize, and add a batch dimension (BCHW)
-     return resize(screen).unsqueeze(0)
-
-
- env.reset()
- plt.figure()
- plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
-            interpolation='none')
- plt.title('Example extracted screen')
- plt.show()
+         x = F.relu(self.layer1(x))
+         x = F.relu(self.layer2(x))
+         return self.layer3(x)
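To see what the new network produces, here is a minimal sketch assuming CartPole's 4 observations and 2 actions (``dummy_state`` and ``net`` are illustrative names, not part of the tutorial):

.. code-block:: python

   import torch

   # Assumes the DQN class defined above has been executed.
   net = DQN(n_observations=4, n_actions=2)
   dummy_state = torch.zeros(1, 4)       # a batch containing one (all-zero) state
   q_values = net(dummy_state)           # shape (1, 2): one Q-value estimate per action
   best_action = q_values.max(1)[1]      # index of the action with the larger Q-value
   print(q_values.shape, best_action)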
######################################################################
@@ -315,28 +243,35 @@ def get_screen():
# episode.
#

+ # BATCH_SIZE is the number of transitions sampled from the replay buffer
+ # GAMMA is the discount factor as mentioned in the previous section
+ # EPS_START is the starting value of epsilon
+ # EPS_END is the final value of epsilon
+ # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+ # TAU is the update rate of the target network
+ # LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
- GAMMA = 0.999
+ GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
- EPS_DECAY = 200
- TARGET_UPDATE = 10
-
- # Get screen size so that we can initialize layers correctly based on shape
- # returned from AI gym. Typical dimensions at this point are close to 3x40x90
- # which is the result of a clamped and down-scaled render buffer in get_screen()
- init_screen = get_screen()
- _, _, screen_height, screen_width = init_screen.shape
+ EPS_DECAY = 1000
+ TAU = 0.005
+ LR = 1e-4
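For a rough feel of the new ``EPS_DECAY`` value, the sketch below evaluates an exponential schedule at a few step counts; it assumes the formula ``EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)`` used by ``select_action`` (unchanged context that is not shown in this diff):

.. code-block:: python

   import math

   EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000

   for steps_done in (0, 500, 1000, 2000, 5000):
       eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
       print(steps_done, round(eps, 3))
   # 0 -> 0.9, 1000 -> 0.363, 5000 -> 0.056: a larger EPS_DECAY stretches exploration out.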
# Get number of actions from gym action space
n_actions = env.action_space.n
-
- policy_net = DQN(screen_height, screen_width, n_actions).to(device)
- target_net = DQN(screen_height, screen_width, n_actions).to(device)
+ # Get the number of state observations
+ if gym.__version__[:4] == '0.26':
+     state, _ = env.reset()
+ elif gym.__version__[:4] == '0.25':
+     state, _ = env.reset(return_info=True)
+ n_observations = len(state)
+
+ policy_net = DQN(n_observations, n_actions).to(device)
+ target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
- target_net.eval()

- optimizer = optim.RMSprop(policy_net.parameters())
+ optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)
@@ -356,14 +291,14 @@ def select_action(state):
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
-         return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
+         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
episode_durations = []


def plot_durations():
-     plt.figure(2)
+     plt.figure(1)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
@@ -394,10 +329,9 @@ def plot_durations():
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
- # added stability. The target network has its weights kept frozen most of
- # the time, but is updated with the policy network's weights every so often.
- # This is usually a set number of steps but we shall use episodes for
- # simplicity.
+ # added stability. The target network is updated at every step with a
+ # `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
+ # the hyperparameter ``TAU``, which was previously defined.
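Written out, the soft update blends the policy network's parameters :math:`\theta` into the target network's parameters :math:`\theta'` (this mirrors the in-code comment added further down in this diff):

.. math:: \theta' \leftarrow \tau \theta + (1 - \tau) \theta'

With ``TAU = 0.005`` the target network trails the policy network slowly, which provides the added stability mentioned above.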
#


def optimize_model():
@@ -430,7 +364,8 @@ def optimize_model():
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
-     next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
+     with torch.no_grad():
+         next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
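As a toy illustration of the Bellman target computed on the line above (the numbers are made up): a non-final transition with reward 1 and a target-network estimate of 10 for the next state yields a target of 0.99 · 10 + 1 = 10.9, while a final transition contributes just its reward.

.. code-block:: python

   import torch

   GAMMA = 0.99
   next_state_values = torch.tensor([10.0, 4.0, 0.0])   # made-up estimates; the last entry is a final state
   reward_batch = torch.tensor([1.0, 1.0, 1.0])
   print(next_state_values * GAMMA + reward_batch)      # tensor([10.9000,  4.9600,  1.0000])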
@@ -441,44 +376,49 @@ def optimize_model():
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
-     for param in policy_net.parameters():
-         param.grad.data.clamp_(-1, 1)
+     # In-place gradient clipping
+     torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
- # the environment and initialize the ``state`` Tensor. Then, we sample
- # an action, execute it, observe the next screen and the reward (always
+ # the environment and obtain the initial ``state`` Tensor. Then, we sample
+ # an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
- # Below, `num_episodes` is set small. You should download
- # the notebook and run lot more epsiodes, such as 300+ for meaningful
- # duration improvements.
+ # Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50
+ # episodes are scheduled so training does not take too long. However, 50
+ # episodes is not enough to observe good performance on CartPole.
+ # You should see the model consistently achieve 500 steps within 600 training
+ # episodes. Training RL agents can be a noisy process, so restarting training
+ # can produce better results if convergence is not observed.
#

- num_episodes = 50
+ if torch.cuda.is_available():
+     num_episodes = 600
+ else:
+     num_episodes = 50
+
for i_episode in range(num_episodes):
-     # Initialize the environment and state
-     env.reset()
-     last_screen = get_screen()
-     current_screen = get_screen()
-     state = current_screen - last_screen
+     # Initialize the environment and get its state
+     if gym.__version__[:4] == '0.26':
+         state, _ = env.reset()
+     elif gym.__version__[:4] == '0.25':
+         state, _ = env.reset(return_info=True)
+     state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
-         # Select and perform an action
        action = select_action(state)
-         _, reward, done, _, _ = env.step(action.item())
+         observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
+         done = terminated or truncated

-         # Observe new state
-         last_screen = current_screen
-         current_screen = get_screen()
-         if not done:
-             next_state = current_screen - last_screen
-         else:
+         if terminated:
            next_state = None
+         else:
+             next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)
@@ -488,18 +428,21 @@ def optimize_model():
        # Perform one step of the optimization (on the policy network)
        optimize_model()
+
+         # Soft update of the target network's weights
+         # θ′ ← τ θ + (1 − τ)θ′
+         target_net_state_dict = target_net.state_dict()
+         policy_net_state_dict = policy_net.state_dict()
+         for key in policy_net_state_dict:
+             target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
+         target_net.load_state_dict(target_net_state_dict)
+
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

-     # Update the target network, copying all weights and biases in DQN
-     if t % TARGET_UPDATE == 0:
-         target_net.load_state_dict(policy_net.state_dict())
-
print('Complete')
- env.render()
- env.close()
plt.ioff()
plt.show()
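Once training finishes, one optional way to sanity-check the learned policy is to roll out a single greedy episode. This is a hypothetical sketch, not part of the tutorial; it assumes the script above has already run, the gym 0.26 API, and a display for ``render_mode='human'`` (``eval_env`` is just an illustrative name):

.. code-block:: python

   # Roll out one episode greedily with the trained policy_net.
   eval_env = gym.make('CartPole-v1', render_mode='human')
   state, _ = eval_env.reset()
   state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
   done = False
   while not done:
       with torch.no_grad():
           action = policy_net(state).max(1)[1].view(1, 1)
       observation, reward, terminated, truncated, _ = eval_env.step(action.item())
       done = terminated or truncated
       state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
   eval_env.close()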
@@ -512,6 +455,6 @@ def optimize_model():
# step sample from the gym environment. We record the results in the
# replay memory and also run an optimization step on every iteration.
# Optimization picks a random batch from the replay memory to do training of the
- # new policy. "Older" target_net is also used in optimization to compute the
- # expected Q values; it is updated occasionally to keep it current.
+ # new policy. The "older" target_net is also used in optimization to compute the
+ # expected Q values. A soft update of its weights is performed at every step.
#