[WIP] Improve training of DQN tutorial #2030

Closed
wants to merge 28 commits into from
Changes from all commits
Commits
28 commits
d8d7f40
increased model capacity and input resolution
SiftingSands Sep 7, 2022
9f43173
changed out optimizer and greatly increased replay buffer
SiftingSands Sep 7, 2022
ee25c65
added reward shaping based solely on duration. intention was to reduc…
SiftingSands Sep 7, 2022
dffcd1f
forgot to include my increase to EPS_DECAY. Might as well change the…
SiftingSands Sep 8, 2022
d899d8d
revisions to batch norm behavior as suggested by vmoens
SiftingSands Sep 8, 2022
2a59217
switched to state vector input and modified hyps so training succeeds …
SiftingSands Sep 24, 2022
db90afd
numerous doc changes. removed unused imports. used gym's action space…
SiftingSands Sep 24, 2022
8e507a2
minor doc change. removed hard coding of network input size
SiftingSands Sep 24, 2022
0f0c07b
removed unneeded timelimit wrapper
SiftingSands Sep 25, 2022
d3bff00
Fixed termination vs truncation behavior. Remove the timelimit import…
SiftingSands Sep 27, 2022
ab20d96
Merge branch 'master' into DQN_revise_training
Sep 28, 2022
70bffde
Added missing # to see if the webpage will render
SiftingSands Sep 28, 2022
04c3b32
Remove comment w/ special chars for webpage render
SiftingSands Sep 28, 2022
26f774f
Naive removal of block beginning with comments to see if it fixes the…
SiftingSands Sep 28, 2022
2e629a6
undid more docstring changes for testing
SiftingSands Sep 28, 2022
068260d
undid the last few plausible docstring changes and added the input ex…
SiftingSands Sep 28, 2022
80d133a
minor docstring changes
SiftingSands Oct 1, 2022
e52cf66
updated diagram for soft updates
SiftingSands Oct 1, 2022
ce0ae8f
Merge branch 'master' into DQN_revise_training
Oct 5, 2022
a3e5484
Merge branch 'master' into DQN_revise_training
Oct 5, 2022
271ead7
Merge branch 'master' into DQN_revise_training
malfet Oct 15, 2022
b3d65d0
gym version handling authored by https://github.com/pseudo-rnd-though…
SiftingSands Nov 2, 2022
1a878bf
fix merge
SiftingSands Nov 2, 2022
68e29a3
Merge branch 'master' into DQN_revise_training
SiftingSands Nov 2, 2022
1015af6
more version handling for v.25 and v.26
SiftingSands Nov 5, 2022
a3dfc6e
Merge branch 'DQN_revise_training' of github.com:SiftingSands/tutoria…
SiftingSands Nov 5, 2022
df822d4
Merge branch 'master' into DQN_revise_training
SiftingSands Nov 5, 2022
b9bc71c
Merge branch 'master' into DQN_revise_training
SiftingSands Nov 30, 2022
1 change: 1 addition & 0 deletions _static/img/reinforcement_learning_diagram.drawio
@@ -0,0 +1 @@
<mxfile host="app.diagrams.net" modified="2022-10-01T16:00:40.980Z" agent="5.0 (X11)" etag="_qbqVrrm3wUvm_i0-Q9T" version="20.4.0" type="device"><diagram id="aSXDm0BvLjt-Za0vl2Tv" name="Page-1">5Vpbc+MmFP41nmkfmpGEpMiPjTftzrTZZtbbbbYvHSxhiRQJFeHb/vqChG4gx95ElqfTeCaGwwEO37lwDskMLNL9zwzmyQONEJk5VrSfgXczx7GtuSW+JOVQUXzXrggxw5FiaglL/BXVMxV1gyNU9Bg5pYTjvE8MaZahkPdokDG667OtKenvmsMYGYRlCIlJ/QNHPKmogWe19PcIx0m9s22pkRTWzIpQJDCiuw4J3M/AglHKq1a6XyAiwatxqeb9dGS0EYyhjJ8z4flQfCient3PwfKXVfTn40P6/vMPapUtJBt14I8oJ/AgaA8opeygZOeHGhBGN1mE5JrWDNztEszRMoehHN0JExC0hKdE9GzRXGNCFpRQVs4Fa09+BD1mMMJC7s7YqvyIsYIz+jfqjPjljxhRsiLG0f4oCHYDrbBJRFPE5SEsNcGt1aPMMVDdXatbu1ZY0tGrr2hQmVPcrNwiLhoK9G9QgGMo4Lec41T6gWN9535v4C/WFLaPTmM/AlrebR8t2z0TruBScAEDLgaziKaCtoI8TAy4wg3bltYqIUFZ9KOMCaIbElgUOOxD1rftai0UGQFCg0/sRzcsRKfdjEMWI37KGkx1dOD2BtCuaQwRyPG2L+6QCtQOjxSLgxz1DRdoWqyOqWZ1I42+kNdfCOjmUOFgLCR0U0aemi2XDMVxgQ3ztK0X5fJtjR/0+EWjkqA1z0YHr7dYz7DYR0pwKM/5AfFRw2sEUbAOh8PrLYI+sgbDaxig1foy4dWxrh1fAzOACodeqi5lPKExzSC5b6laGGh5fqU0V8g/I84PKluBG06HQu8okcN/W+Q4OyS8CWTfsPFPlbz/Cxu/eg5hm0nEmPcg2mP+JNs3nup96Yy823c7h/HvTu8/cXfqV9H8lXen5xxJuUa+O91A2yd4+e709LvTP8Hvvsh/mbvWNtPDUR0hE+I9NVNF50vrFrLb+kLZq52hcaCO+9hTuY9zpvt413Qf3Vqc29e6z1xbyJkm9TSKb51fS4mdKVJP2zXc4fc8grwsLb3rlpaODog3cItaU96ijUdNmirWscHuXq03jjdVeLDnZ8aHI+qcJsGspewFdp8Iee8ivJU7Ehxn5YD/z0a+qN0RtOZtT7Ri9Q1Tac3ZqsjLvvWJQZzhLBbNUmtqXSFnuXQzb5zd7Bvxa5FQWkgvbB4vJDxCgXVlCEOOaXZhURwpylJQiRQFZdsL7wfkfh9RSFkkGql6XrQ2KiRddG9X7t2+rF10L6/ElUpu5VZ/ZWUt1D/piuk76/K8pWyq5S+lHiVi23oGaA9E7PlAxG4Yxw/ZZr4X1q5Vu9AE6V8wP5UAyt4jYlgcG7HrlUhVGL1WkgeO5EDf/r5oDdcuo9dIeqUPXk7ygK/xn3iPNACxJkgKHTMpHNVJBmod6+Z2snzmqmWMrlCgVx/nWjjQLc+7jIUDvYw5ZeFA43emsFjzCf0iYd2ava6q7z2LTVbX18XdyaDvX9UjNIMBevl2tkdo71VATyrG8ghd4LcV6qLb/oW/Ym//TwLc/ws=</diagram></mxfile>
Binary file modified _static/img/reinforcement_learning_diagram.jpg
239 changes: 91 additions & 148 deletions intermediate_source/reinforcement_q_learning.py
@@ -6,7 +6,7 @@


This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.

**Task**

@@ -30,35 +30,30 @@

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
However, neural networks can solve the task purely by looking at the
scene, so we'll use a patch of the screen centered on the cart as an
input. Because of this, our results aren't directly comparable to the
ones from the official leaderboard - our task is much harder.
Unfortunately this does slow down the training, because we have to
render all the frames.
We take these 4 inputs without any scaling and pass them through a
small fully-connected network with 2 outputs, one for each action.
The network is trained to predict the expected value for each action,
given the input state. The action with the highest expected value is
then chosen.
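
As a quick check (an illustrative aside, assuming gym's ``CartPole-v1``
environment), the observation and action spaces can be inspected directly:

.. code-block:: python

   import gym

   env = gym.make("CartPole-v1")
   print(env.observation_space.shape)  # (4,) -> cart position, cart velocity, pole angle, pole angular velocity
   print(env.action_space.n)           # 2    -> push the cart to the left or to the right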

Strictly speaking, we will present the state as the difference between
the current screen patch and the previous one. This will allow the agent
to take the velocity of the pole into account from one image.

**Packages**


First, let's import the needed packages. We need
`gym <https://github.com/openai/gym>`__ for the environment,
which can be installed by using `pip`. If you are running this in Google Colab, run:

.. code-block:: bash

%%bash
pip3 install gym[classic_control]

We'll also use the following from PyTorch:

- neural networks (``torch.nn``)
- optimization (``torch.optim``)
- automatic differentiation (``torch.autograd``)
- utilities for vision tasks (``torchvision`` - `a separate
package <https://github.com/pytorch/vision>`__).

"""

@@ -70,19 +65,18 @@
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T


if gym.__version__ < '0.26':
env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
if gym.__version__[:4] == '0.26':
env = gym.make('CartPole-v1')
elif gym.__version__[:4] == '0.25':
env = gym.make('CartPole-v1', new_step_api=True)
else:
env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped
raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
@@ -152,9 +146,11 @@ def __len__(self):
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
# that ensures the sum converges. It makes rewards from the uncertain far
# future less important for our agent than the ones in the near future
# that it can be fairly confident about.
# that ensures the sum converges. A lower :math:`\gamma` makes
# rewards from the uncertain far future less important for our agent
# than the ones in the near future that it can be fairly confident
# about. It also encourages agents to collect reward closer in time
# than equivalent rewards that are temporally far in the future.
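# As a small worked example (using the ``GAMMA = 0.99`` that we set later in
# this tutorial), a reward received 100 steps in the future contributes only
# :math:`0.99^{100} \approx 0.37` times as much to the return as the same
# reward received immediately.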
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
@@ -177,7 +173,7 @@ def __len__(self):
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
# .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a))
# .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
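#
# For instance (an illustrative aside with made-up numbers): if the network
# currently estimates :math:`Q(s, a) = 1.0`, the step returns reward
# :math:`r = 1`, :math:`\gamma = 0.99`, and the best next-state estimate is
# :math:`\max_{a'} Q(s', a') = 0.5`, then the target is
# :math:`1 + 0.99 \cdot 0.5 = 1.495` and :math:`\delta = 1.0 - 1.495 = -0.495`.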
#
# To minimise this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
@@ -211,86 +207,18 @@ def __len__(self):

class DQN(nn.Module):

def __init__(self, h, w, outputs):
def __init__(self, n_observations, n_actions):
super(DQN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
self.bn3 = nn.BatchNorm2d(32)

# Number of Linear input connections depends on output of conv2d layers
# and therefore the input image size, so compute it.
def conv2d_size_out(size, kernel_size = 5, stride = 2):
return (size - (kernel_size - 1) - 1) // stride + 1
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
linear_input_size = convw * convh * 32
self.head = nn.Linear(linear_input_size, outputs)
self.layer1 = nn.Linear(n_observations, 128)
self.layer2 = nn.Linear(128, 128)
self.layer3 = nn.Linear(128, n_actions)

# Called with either one element to determine next action, or a batch
# during optimization. Returns tensor([[left0exp,right0exp]...]).
def forward(self, x):
x = x.to(device)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
return self.head(x.view(x.size(0), -1))


######################################################################
# Input extraction
# ^^^^^^^^^^^^^^^^
#
# The code below are utilities for extracting and processing rendered
# images from the environment. It uses the ``torchvision`` package, which
# makes it easy to compose image transforms. Once you run the cell it will
# display an example patch that it extracted.
#

resize = T.Compose([T.ToPILImage(),
T.Resize(40, interpolation=Image.CUBIC),
T.ToTensor()])


def get_cart_location(screen_width):
world_width = env.x_threshold * 2
scale = screen_width / world_width
return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART

def get_screen():
# Returned screen requested by gym is 400x600x3, but is sometimes larger
# such as 800x1200x3. Transpose it into torch order (CHW).
screen = env.render().transpose((2, 0, 1))
# Cart is in the lower half, so strip off the top and bottom of the screen
_, screen_height, screen_width = screen.shape
screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
view_width = int(screen_width * 0.6)
cart_location = get_cart_location(screen_width)
if cart_location < view_width // 2:
slice_range = slice(view_width)
elif cart_location > (screen_width - view_width // 2):
slice_range = slice(-view_width, None)
else:
slice_range = slice(cart_location - view_width // 2,
cart_location + view_width // 2)
# Strip off the edges, so that we have a square image centered on a cart
screen = screen[:, :, slice_range]
# Convert to float, rescale, convert to torch tensor
# (this doesn't require a copy)
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
screen = torch.from_numpy(screen)
# Resize, and add a batch dimension (BCHW)
return resize(screen).unsqueeze(0)


env.reset()
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
interpolation='none')
plt.title('Example extracted screen')
plt.show()
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
return self.layer3(x)
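

######################################################################
# As a quick sanity check (an illustrative aside, assuming CartPole's
# 4-dimensional state and 2 actions), the network maps a batch of states to
# one Q-value per action:
#
# .. code-block:: python
#
#    net = DQN(n_observations=4, n_actions=2)
#    q_values = net(torch.zeros(1, 4))  # shape (1, 2): one estimate per action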


######################################################################
@@ -315,28 +243,35 @@ def get_screen():
# episode.
#

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
GAMMA = 0.999
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

# Get screen size so that we can initialize layers correctly based on shape
# returned from AI gym. Typical dimensions at this point are close to 3x40x90
# which is the result of a clamped and down-scaled render buffer in get_screen()
init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
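
# For intuition about the exploration constants above (an illustrative sketch
# of the schedule they define): epsilon decays exponentially from EPS_START
# toward EPS_END with a time constant of EPS_DECAY steps,
#   eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)
# so after 1000 steps epsilon is roughly 0.05 + 0.85 * 0.37 ≈ 0.36, and it
# keeps approaching 0.05 as training continues.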

# Get number of actions from gym action space
n_actions = env.action_space.n

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
# Get the number of state observations
if gym.__version__[:4] == '0.26':
state, _ = env.reset()
elif gym.__version__[:4] == '0.25':
state, _ = env.reset(return_info=True)
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


@@ -356,14 +291,14 @@ def select_action(state):
# found, so we pick action with the larger expected reward.
return policy_net(state).max(1)[1].view(1, 1)
else:
return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations():
plt.figure(2)
plt.figure(1)
plt.clf()
durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.title('Training...')
@@ -394,10 +329,9 @@ def plot_durations():
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
# added stability. The target network has its weights kept frozen most of
# the time, but is updated with the policy network's weights every so often.
# This is usually a set number of steps but we shall use episodes for
# simplicity.
# added stability. The target network is updated at every step with a
# `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
# the hyperparameter ``TAU``, which was previously defined.
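# As a concrete illustration (with the ``TAU = 0.005`` defined above), each
# soft update moves every target-network weight only 0.5% of the way toward
# the corresponding policy-network weight,
# :math:`\theta' \leftarrow 0.005\,\theta + 0.995\,\theta'`,
# so the target network changes slowly and smoothly instead of being copied
# wholesale every few episodes.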
#

def optimize_model():
@@ -430,7 +364,8 @@ def optimize_model():
# This is merged based on the mask, such that we'll have either the expected
# state value or 0 in case the state was final.
next_state_values = torch.zeros(BATCH_SIZE, device=device)
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
with torch.no_grad():
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values
expected_state_action_values = (next_state_values * GAMMA) + reward_batch

@@ -441,44 +376,49 @@ def optimize_model():
# Optimize the model
optimizer.zero_grad()
loss.backward()
for param in policy_net.parameters():
param.grad.data.clamp_(-1, 1)
# In-place gradient clipping
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
optimizer.step()


######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
# the environment and initialize the ``state`` Tensor. Then, we sample
# an action, execute it, observe the next screen and the reward (always
# the environment and obtain the initial ``state`` Tensor. Then, we sample
# an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
# Below, `num_episodes` is set small. You should download
# the notebook and run a lot more episodes, such as 300+ for meaningful
# duration improvements.
# Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50
# episodes are scheduled so training does not take too long. However, 50
# episodes is insufficient to observe good performance on CartPole.
# You should see the model consistently reach 500 steps within 600 training
# episodes. Training RL agents can be a noisy process, so restarting training
# can produce better results if convergence is not observed.
#

num_episodes = 50
if torch.cuda.is_available():
num_episodes = 600
else:
num_episodes = 50

for i_episode in range(num_episodes):
# Initialize the environment and state
env.reset()
last_screen = get_screen()
current_screen = get_screen()
state = current_screen - last_screen
# Initialize the environment and get its state
if gym.__version__[:4] == '0.26':
state, _ = env.reset()
elif gym.__version__[:4] == '0.25':
state, _ = env.reset(return_info=True)
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
for t in count():
# Select and perform an action
action = select_action(state)
_, reward, done, _, _ = env.step(action.item())
observation, reward, terminated, truncated, _ = env.step(action.item())
reward = torch.tensor([reward], device=device)
done = terminated or truncated

# Observe new state
last_screen = current_screen
current_screen = get_screen()
if not done:
next_state = current_screen - last_screen
else:
if terminated:
next_state = None
else:
next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

# Store the transition in memory
memory.push(state, action, next_state, reward)
@@ -488,18 +428,21 @@ def optimize_model():

# Perform one step of the optimization (on the policy network)
optimize_model()

# Soft update of the target network's weights
# θ′ ← τ θ + (1 − τ)θ′
target_net_state_dict = target_net.state_dict()
policy_net_state_dict = policy_net.state_dict()
for key in policy_net_state_dict:
target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
target_net.load_state_dict(target_net_state_dict)

if done:
episode_durations.append(t + 1)
plot_durations()
break

# Update the target network, copying all weights and biases in DQN
if t % TARGET_UPDATE == 0:
target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()

@@ -512,6 +455,6 @@ def optimize_model():
# step sample from the gym environment. We record the results in the
# replay memory and also run an optimization step on every iteration.
# Optimization picks a random batch from the replay memory to do training of the
# new policy. "Older" target_net is also used in optimization to compute the
# expected Q values; it is updated occasionally to keep it current.
# new policy. The "older" target_net is also used in optimization to compute the
# expected Q values. A soft update of its weights is performed at every step.
#