
Commit d5abb12

[Refactoring] Improve training of DQN tutorial (2030 copy) (#2145)
1 parent 0050004 commit d5abb12

3 files changed: +91, -147 lines
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
<mxfile host="app.diagrams.net" modified="2022-10-01T16:00:40.980Z" agent="5.0 (X11)" etag="_qbqVrrm3wUvm_i0-Q9T" version="20.4.0" type="device"><diagram id="aSXDm0BvLjt-Za0vl2Tv" name="Page-1">5Vpbc+MmFP41nmkfmpGEpMiPjTftzrTZZtbbbbYvHSxhiRQJFeHb/vqChG4gx95ElqfTeCaGwwEO37lwDskMLNL9zwzmyQONEJk5VrSfgXczx7GtuSW+JOVQUXzXrggxw5FiaglL/BXVMxV1gyNU9Bg5pYTjvE8MaZahkPdokDG667OtKenvmsMYGYRlCIlJ/QNHPKmogWe19PcIx0m9s22pkRTWzIpQJDCiuw4J3M/AglHKq1a6XyAiwatxqeb9dGS0EYyhjJ8z4flQfCient3PwfKXVfTn40P6/vMPapUtJBt14I8oJ/AgaA8opeygZOeHGhBGN1mE5JrWDNztEszRMoehHN0JExC0hKdE9GzRXGNCFpRQVs4Fa09+BD1mMMJC7s7YqvyIsYIz+jfqjPjljxhRsiLG0f4oCHYDrbBJRFPE5SEsNcGt1aPMMVDdXatbu1ZY0tGrr2hQmVPcrNwiLhoK9G9QgGMo4Lec41T6gWN9535v4C/WFLaPTmM/AlrebR8t2z0TruBScAEDLgaziKaCtoI8TAy4wg3bltYqIUFZ9KOMCaIbElgUOOxD1rftai0UGQFCg0/sRzcsRKfdjEMWI37KGkx1dOD2BtCuaQwRyPG2L+6QCtQOjxSLgxz1DRdoWqyOqWZ1I42+kNdfCOjmUOFgLCR0U0aemi2XDMVxgQ3ztK0X5fJtjR/0+EWjkqA1z0YHr7dYz7DYR0pwKM/5AfFRw2sEUbAOh8PrLYI+sgbDaxig1foy4dWxrh1fAzOACodeqi5lPKExzSC5b6laGGh5fqU0V8g/I84PKluBG06HQu8okcN/W+Q4OyS8CWTfsPFPlbz/Cxu/eg5hm0nEmPcg2mP+JNs3nup96Yy823c7h/HvTu8/cXfqV9H8lXen5xxJuUa+O91A2yd4+e709LvTP8Hvvsh/mbvWNtPDUR0hE+I9NVNF50vrFrLb+kLZq52hcaCO+9hTuY9zpvt413Qf3Vqc29e6z1xbyJkm9TSKb51fS4mdKVJP2zXc4fc8grwsLb3rlpaODog3cItaU96ijUdNmirWscHuXq03jjdVeLDnZ8aHI+qcJsGspewFdp8Iee8ivJU7Ehxn5YD/z0a+qN0RtOZtT7Ri9Q1Tac3ZqsjLvvWJQZzhLBbNUmtqXSFnuXQzb5zd7Bvxa5FQWkgvbB4vJDxCgXVlCEOOaXZhURwpylJQiRQFZdsL7wfkfh9RSFkkGql6XrQ2KiRddG9X7t2+rF10L6/ElUpu5VZ/ZWUt1D/piuk76/K8pWyq5S+lHiVi23oGaA9E7PlAxG4Yxw/ZZr4X1q5Vu9AE6V8wP5UAyt4jYlgcG7HrlUhVGL1WkgeO5EDf/r5oDdcuo9dIeqUPXk7ygK/xn3iPNACxJkgKHTMpHNVJBmod6+Z2snzmqmWMrlCgVx/nWjjQLc+7jIUDvYw5ZeFA43emsFjzCf0iYd2ava6q7z2LTVbX18XdyaDvX9UjNIMBevl2tkdo71VATyrG8ghd4LcV6qLb/oW/Ym//TwLc/ws=</diagram></mxfile>
10.8 KB (file preview not rendered)

intermediate_source/reinforcement_q_learning.py

Lines changed: 90 additions & 147 deletions
@@ -6,7 +6,7 @@


 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.

 **Task**

@@ -30,22 +30,19 @@

 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
-However, neural networks can solve the task purely by looking at the
-scene, so we'll use a patch of the screen centered on the cart as an
-input. Because of this, our results aren't directly comparable to the
-ones from the official leaderboard - our task is much harder.
-Unfortunately this does slow down the training, because we have to
-render all the frames.
+We take these 4 inputs without any scaling and pass them through a
+small fully-connected network with 2 outputs, one for each action.
+The network is trained to predict the expected value for each action,
+given the input state. The action with the highest expected value is
+then chosen.

-Strictly speaking, we will present the state as the difference between
-the current screen patch and the previous one. This will allow the agent
-to take the velocity of the pole into account from one image.

 **Packages**


 First, let's import needed packages. Firstly, we need
 `gym <https://github.com/openai/gym>`__ for the environment
+Install by using `pip`. If you are running this in Google colab, run:

 .. code-block:: bash

@@ -57,8 +54,6 @@
 - neural networks (``torch.nn``)
 - optimization (``torch.optim``)
 - automatic differentiation (``torch.autograd``)
-- utilities for vision tasks (``torchvision`` - `a separate
-  package <https://github.com/pytorch/vision>`__).

 """

@@ -70,19 +65,18 @@
 import matplotlib.pyplot as plt
 from collections import namedtuple, deque
 from itertools import count
-from PIL import Image

 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
-import torchvision.transforms as T

-
-if gym.__version__ < '0.26':
-    env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
+if gym.__version__[:4] == '0.26':
+    env = gym.make('CartPole-v1')
+elif gym.__version__[:4] == '0.25':
+    env = gym.make('CartPole-v1', new_step_api=True)
 else:
-    env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped
+    raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")

 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
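The refactor relies on ``CartPole-v1`` exposing a 4-dimensional observation and 2 discrete actions. A minimal sanity-check sketch (not part of this commit), assuming the gym 0.26 reset API:

.. code-block:: python

    # Hypothetical sanity check, not in the commit: inspect the spaces the new
    # fully-connected DQN relies on (4 observations in, 2 actions out).
    import gym

    env = gym.make('CartPole-v1')
    state, _ = env.reset()              # gym 0.26: returns (observation, info)
    print(env.observation_space.shape)  # (4,) -> cart position/velocity, pole angle/angular velocity
    print(env.action_space.n)           # 2   -> push cart left or right
    print(state)                        # e.g. array([ 0.01, -0.02,  0.03,  0.04], dtype=float32)
    env.close()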
@@ -152,9 +146,11 @@ def __len__(self):
 # :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
 # :math:`R_{t_0}` is also known as the *return*. The discount,
 # :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
-# that ensures the sum converges. It makes rewards from the uncertain far
-# future less important for our agent than the ones in the near future
-# that it can be fairly confident about.
+# that ensures the sum converges. A lower :math:`\gamma` makes
+# rewards from the uncertain far future less important for our agent
+# than the ones in the near future that it can be fairly confident
+# about. It also encourages agents to collect reward closer in time
+# than equivalent rewards further in the future.
 #
 # The main idea behind Q-learning is that if we had a function
 # :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
@@ -177,7 +173,7 @@ def __len__(self):
 # The difference between the two sides of the equality is known as the
 # temporal difference error, :math:`\delta`:
 #
-# .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a))
+# .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
 #
 # To minimise this error, we will use the `Huber
 # loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
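For reference, the Huber loss the tutorial links to is the standard piecewise definition with threshold 1, applied to a batch :math:`B` of transitions (shown here for context; this formula is not part of the diff):

.. math::

    \mathcal{L} = \frac{1}{|B|} \sum_{(s, a, s', r) \in B} \mathcal{L}(\delta), \quad
    \text{where} \quad
    \mathcal{L}(\delta) =
    \begin{cases}
        \frac{1}{2} \delta^2   & \text{for } |\delta| \le 1, \\
        |\delta| - \frac{1}{2} & \text{otherwise.}
    \end{cases}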
@@ -211,86 +207,18 @@ def __len__(self):

 class DQN(nn.Module):

-    def __init__(self, h, w, outputs):
+    def __init__(self, n_observations, n_actions):
         super(DQN, self).__init__()
-        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-        self.bn3 = nn.BatchNorm2d(32)
-
-        # Number of Linear input connections depends on output of conv2d layers
-        # and therefore the input image size, so compute it.
-        def conv2d_size_out(size, kernel_size = 5, stride = 2):
-            return (size - (kernel_size - 1) - 1) // stride + 1
-        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-        linear_input_size = convw * convh * 32
-        self.head = nn.Linear(linear_input_size, outputs)
+        self.layer1 = nn.Linear(n_observations, 128)
+        self.layer2 = nn.Linear(128, 128)
+        self.layer3 = nn.Linear(128, n_actions)

     # Called with either one element to determine next action, or a batch
     # during optimization. Returns tensor([[left0exp,right0exp]...]).
     def forward(self, x):
-        x = x.to(device)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = F.relu(self.bn3(self.conv3(x)))
-        return self.head(x.view(x.size(0), -1))
-
-
-######################################################################
-# Input extraction
-# ^^^^^^^^^^^^^^^^
-#
-# The code below are utilities for extracting and processing rendered
-# images from the environment. It uses the ``torchvision`` package, which
-# makes it easy to compose image transforms. Once you run the cell it will
-# display an example patch that it extracted.
-#
-
-resize = T.Compose([T.ToPILImage(),
-                    T.Resize(40, interpolation=Image.CUBIC),
-                    T.ToTensor()])
-
-
-def get_cart_location(screen_width):
-    world_width = env.x_threshold * 2
-    scale = screen_width / world_width
-    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
-
-def get_screen():
-    # Returned screen requested by gym is 400x600x3, but is sometimes larger
-    # such as 800x1200x3. Transpose it into torch order (CHW).
-    screen = env.render().transpose((2, 0, 1))
-    # Cart is in the lower half, so strip off the top and bottom of the screen
-    _, screen_height, screen_width = screen.shape
-    screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
-    view_width = int(screen_width * 0.6)
-    cart_location = get_cart_location(screen_width)
-    if cart_location < view_width // 2:
-        slice_range = slice(view_width)
-    elif cart_location > (screen_width - view_width // 2):
-        slice_range = slice(-view_width, None)
-    else:
-        slice_range = slice(cart_location - view_width // 2,
-                            cart_location + view_width // 2)
-    # Strip off the edges, so that we have a square image centered on a cart
-    screen = screen[:, :, slice_range]
-    # Convert to float, rescale, convert to torch tensor
-    # (this doesn't require a copy)
-    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
-    screen = torch.from_numpy(screen)
-    # Resize, and add a batch dimension (BCHW)
-    return resize(screen).unsqueeze(0)
-
-
-env.reset()
-plt.figure()
-plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
-           interpolation='none')
-plt.title('Example extracted screen')
-plt.show()
+        x = F.relu(self.layer1(x))
+        x = F.relu(self.layer2(x))
+        return self.layer3(x)


 ######################################################################
@@ -315,28 +243,35 @@ def get_screen():
 # episode.
 #

+# BATCH_SIZE is the number of transitions sampled from the replay buffer
+# GAMMA is the discount factor as mentioned in the previous section
+# EPS_START is the starting value of epsilon
+# EPS_END is the final value of epsilon
+# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+# TAU is the update rate of the target network
+# LR is the learning rate of the AdamW optimizer
 BATCH_SIZE = 128
-GAMMA = 0.999
+GAMMA = 0.99
 EPS_START = 0.9
 EPS_END = 0.05
-EPS_DECAY = 200
-TARGET_UPDATE = 10
-
-# Get screen size so that we can initialize layers correctly based on shape
-# returned from AI gym. Typical dimensions at this point are close to 3x40x90
-# which is the result of a clamped and down-scaled render buffer in get_screen()
-init_screen = get_screen()
-_, _, screen_height, screen_width = init_screen.shape
+EPS_DECAY = 1000
+TAU = 0.005
+LR = 1e-4

 # Get number of actions from gym action space
 n_actions = env.action_space.n
-
-policy_net = DQN(screen_height, screen_width, n_actions).to(device)
-target_net = DQN(screen_height, screen_width, n_actions).to(device)
+# Get the number of state observations
+if gym.__version__[:4] == '0.26':
+    state, _ = env.reset()
+elif gym.__version__[:4] == '0.25':
+    state, _ = env.reset(return_info=True)
+n_observations = len(state)
+
+policy_net = DQN(n_observations, n_actions).to(device)
+target_net = DQN(n_observations, n_actions).to(device)
 target_net.load_state_dict(policy_net.state_dict())
-target_net.eval()

-optimizer = optim.RMSprop(policy_net.parameters())
+optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
 memory = ReplayMemory(10000)

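A minimal standalone sketch (not part of the commit) of what the network construction above produces for CartPole, with ``n_observations = 4`` and ``n_actions = 2``:

.. code-block:: python

    # Hypothetical shape check: the refactored DQN maps a batch of 4-dimensional
    # observations to 2 Q-values, one per action (mirrors the class in the diff).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DQN(nn.Module):
        def __init__(self, n_observations, n_actions):
            super().__init__()
            self.layer1 = nn.Linear(n_observations, 128)
            self.layer2 = nn.Linear(128, 128)
            self.layer3 = nn.Linear(128, n_actions)

        def forward(self, x):
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
            return self.layer3(x)

    net = DQN(n_observations=4, n_actions=2)
    q_values = net(torch.zeros(1, 4))        # one dummy state
    print(q_values.shape)                    # torch.Size([1, 2])
    print(q_values.max(1)[1].view(1, 1))     # greedy action index, as in select_action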
@@ -356,14 +291,14 @@ def select_action(state):
             # found, so we pick action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
     else:
-        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
+        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


 episode_durations = []


 def plot_durations():
-    plt.figure(2)
+    plt.figure(1)
     plt.clf()
     durations_t = torch.tensor(episode_durations, dtype=torch.float)
     plt.title('Training...')
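The ``eps_threshold`` computation at the top of ``select_action`` is unchanged by this commit and therefore not shown in the hunk; a sketch of the exponential schedule that ``EPS_START``, ``EPS_END`` and ``EPS_DECAY`` typically drive (treat the exact expression as an assumption):

.. code-block:: python

    # Hypothetical illustration of the epsilon-greedy decay controlled by EPS_DECAY.
    import math

    EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000

    def eps_threshold(steps_done):
        # Decays from EPS_START toward EPS_END; a larger EPS_DECAY means a slower decay.
        return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

    for steps in (0, 500, 1000, 5000):
        print(steps, round(eps_threshold(steps), 3))   # 0.9, 0.566, 0.363, 0.056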
@@ -394,10 +329,9 @@ def plot_durations():
 # :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
 # loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
 # state. We also use a target network to compute :math:`V(s_{t+1})` for
-# added stability. The target network has its weights kept frozen most of
-# the time, but is updated with the policy network's weights every so often.
-# This is usually a set number of steps but we shall use episodes for
-# simplicity.
+# added stability. The target network is updated at every step with a
+# `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
+# the hyperparameter ``TAU``, which was previously defined.
 #

 def optimize_model():
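Written out, the soft update referenced here (and implemented in the training-loop hunk further down) blends the policy weights :math:`\theta` into the target weights :math:`\theta'` with :math:`\tau =` ``TAU``:

.. math::

    \theta' \leftarrow \tau \theta + (1 - \tau) \theta'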
@@ -430,7 +364,8 @@ def optimize_model():
     # This is merged based on the mask, such that we'll have either the expected
     # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
-    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
+    with torch.no_grad():
+        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

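In symbols, with :math:`Q_{\theta'}` the target network, the tensor ``expected_state_action_values`` computed here holds

.. math::

    y = r + \gamma \max_{a'} Q_{\theta'}(s', a'),

with the max term set to :math:`0` for final states (entries outside ``non_final_mask``).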
@@ -441,44 +376,49 @@ def optimize_model():
     # Optimize the model
     optimizer.zero_grad()
     loss.backward()
-    for param in policy_net.parameters():
-        param.grad.data.clamp_(-1, 1)
+    # In-place gradient clipping
+    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
     optimizer.step()


 ######################################################################
 #
 # Below, you can find the main training loop. At the beginning we reset
-# the environment and initialize the ``state`` Tensor. Then, we sample
-# an action, execute it, observe the next screen and the reward (always
+# the environment and obtain the initial ``state`` Tensor. Then, we sample
+# an action, execute it, observe the next state and the reward (always
 # 1), and optimize our model once. When the episode ends (our model
 # fails), we restart the loop.
 #
-# Below, `num_episodes` is set small. You should download
-# the notebook and run lot more epsiodes, such as 300+ for meaningful
-# duration improvements.
+# Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50
+# episodes are scheduled so training does not take too long. However, 50
+# episodes is insufficient to observe good performance on CartPole.
+# You should see the model consistently achieve 500 steps within 600 training
+# episodes. Training RL agents can be a noisy process, so restarting training
+# can produce better results if convergence is not observed.
 #

-num_episodes = 50
+if torch.cuda.is_available():
+    num_episodes = 600
+else:
+    num_episodes = 50
+
 for i_episode in range(num_episodes):
-    # Initialize the environment and state
-    env.reset()
-    last_screen = get_screen()
-    current_screen = get_screen()
-    state = current_screen - last_screen
+    # Initialize the environment and get its state
+    if gym.__version__[:4] == '0.26':
+        state, _ = env.reset()
+    elif gym.__version__[:4] == '0.25':
+        state, _ = env.reset(return_info=True)
+    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
     for t in count():
-        # Select and perform an action
         action = select_action(state)
-        _, reward, done, _, _ = env.step(action.item())
+        observation, reward, terminated, truncated, _ = env.step(action.item())
         reward = torch.tensor([reward], device=device)
+        done = terminated or truncated

-        # Observe new state
-        last_screen = current_screen
-        current_screen = get_screen()
-        if not done:
-            next_state = current_screen - last_screen
-        else:
+        if terminated:
             next_state = None
+        else:
+            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

         # Store the transition in memory
         memory.push(state, action, next_state, reward)
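For context, ``memory.push`` stores a ``Transition`` namedtuple that ``optimize_model`` later samples in random batches; the ``ReplayMemory`` class itself is unchanged by this commit, so the following self-contained sketch only illustrates the pattern:

.. code-block:: python

    # Hypothetical, self-contained illustration of the Transition/replay-buffer pattern.
    import random
    from collections import namedtuple, deque

    Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

    buffer = deque(maxlen=10000)
    buffer.append(Transition(state=[0.0]*4, action=0, next_state=[0.1]*4, reward=1.0))
    buffer.append(Transition(state=[0.1]*4, action=1, next_state=None, reward=1.0))  # terminal

    sampled = random.sample(buffer, 2)
    batch = Transition(*zip(*sampled))   # regroup into batched fields, as optimize_model does
    print(batch.reward)                  # (1.0, 1.0)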
@@ -488,18 +428,21 @@ def optimize_model():

         # Perform one step of the optimization (on the policy network)
         optimize_model()
+
+        # Soft update of the target network's weights
+        # θ′ ← τ θ + (1 −τ )θ′
+        target_net_state_dict = target_net.state_dict()
+        policy_net_state_dict = policy_net.state_dict()
+        for key in policy_net_state_dict:
+            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
+        target_net.load_state_dict(target_net_state_dict)
+
         if done:
             episode_durations.append(t + 1)
             plot_durations()
             break

-    # Update the target network, copying all weights and biases in DQN
-    if t % TARGET_UPDATE == 0:
-        target_net.load_state_dict(policy_net.state_dict())
-
 print('Complete')
-env.render()
-env.close()
 plt.ioff()
 plt.show()

@@ -512,6 +455,6 @@ def optimize_model():
 # step sample from the gym environment. We record the results in the
 # replay memory and also run optimization step on every iteration.
 # Optimization picks a random batch from the replay memory to do training of the
-# new policy. "Older" target_net is also used in optimization to compute the
-# expected Q values; it is updated occasionally to keep it current.
+# new policy. The "older" target_net is also used in optimization to compute the
+# expected Q values. A soft update of its weights is performed at every step.
 #
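A natural follow-up, not included in the commit, is to roll out the trained ``policy_net`` greedily (and optionally save its weights); a sketch assuming the names defined in the tutorial above and the gym 0.26 API:

.. code-block:: python

    # Hypothetical post-training usage: one greedy episode with the trained policy_net.
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    total_reward = 0.0
    while True:
        with torch.no_grad():
            action = policy_net(state).max(1)[1].view(1, 1)   # always exploit, no epsilon
        observation, reward, terminated, truncated, _ = env.step(action.item())
        total_reward += reward
        if terminated or truncated:
            break
        state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

    print(f"Greedy episode return: {total_reward}")
    torch.save(policy_net.state_dict(), "cartpole_dqn.pt")    # file name is arbitrary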
