diff --git a/_static/img/reinforcement_learning_diagram.drawio b/_static/img/reinforcement_learning_diagram.drawio
new file mode 100644
index 00000000000..2ff4e6f0270
--- /dev/null
+++ b/_static/img/reinforcement_learning_diagram.drawio
@@ -0,0 +1 @@
+5Vpbc+MmFP41nmkfmpGEpMiPjTftzrTZZtbbbbYvHSxhiRQJFeHb/vqChG4gx95ElqfTeCaGwwEO37lwDskMLNL9zwzmyQONEJk5VrSfgXczx7GtuSW+JOVQUXzXrggxw5FiaglL/BXVMxV1gyNU9Bg5pYTjvE8MaZahkPdokDG667OtKenvmsMYGYRlCIlJ/QNHPKmogWe19PcIx0m9s22pkRTWzIpQJDCiuw4J3M/AglHKq1a6XyAiwatxqeb9dGS0EYyhjJ8z4flQfCient3PwfKXVfTn40P6/vMPapUtJBt14I8oJ/AgaA8opeygZOeHGhBGN1mE5JrWDNztEszRMoehHN0JExC0hKdE9GzRXGNCFpRQVs4Fa09+BD1mMMJC7s7YqvyIsYIz+jfqjPjljxhRsiLG0f4oCHYDrbBJRFPE5SEsNcGt1aPMMVDdXatbu1ZY0tGrr2hQmVPcrNwiLhoK9G9QgGMo4Lec41T6gWN9535v4C/WFLaPTmM/AlrebR8t2z0TruBScAEDLgaziKaCtoI8TAy4wg3bltYqIUFZ9KOMCaIbElgUOOxD1rftai0UGQFCg0/sRzcsRKfdjEMWI37KGkx1dOD2BtCuaQwRyPG2L+6QCtQOjxSLgxz1DRdoWqyOqWZ1I42+kNdfCOjmUOFgLCR0U0aemi2XDMVxgQ3ztK0X5fJtjR/0+EWjkqA1z0YHr7dYz7DYR0pwKM/5AfFRw2sEUbAOh8PrLYI+sgbDaxig1foy4dWxrh1fAzOACodeqi5lPKExzSC5b6laGGh5fqU0V8g/I84PKluBG06HQu8okcN/W+Q4OyS8CWTfsPFPlbz/Cxu/eg5hm0nEmPcg2mP+JNs3nup96Yy823c7h/HvTu8/cXfqV9H8lXen5xxJuUa+O91A2yd4+e709LvTP8Hvvsh/mbvWNtPDUR0hE+I9NVNF50vrFrLb+kLZq52hcaCO+9hTuY9zpvt413Qf3Vqc29e6z1xbyJkm9TSKb51fS4mdKVJP2zXc4fc8grwsLb3rlpaODog3cItaU96ijUdNmirWscHuXq03jjdVeLDnZ8aHI+qcJsGspewFdp8Iee8ivJU7Ehxn5YD/z0a+qN0RtOZtT7Ri9Q1Tac3ZqsjLvvWJQZzhLBbNUmtqXSFnuXQzb5zd7Bvxa5FQWkgvbB4vJDxCgXVlCEOOaXZhURwpylJQiRQFZdsL7wfkfh9RSFkkGql6XrQ2KiRddG9X7t2+rF10L6/ElUpu5VZ/ZWUt1D/piuk76/K8pWyq5S+lHiVi23oGaA9E7PlAxG4Yxw/ZZr4X1q5Vu9AE6V8wP5UAyt4jYlgcG7HrlUhVGL1WkgeO5EDf/r5oDdcuo9dIeqUPXk7ygK/xn3iPNACxJkgKHTMpHNVJBmod6+Z2snzmqmWMrlCgVx/nWjjQLc+7jIUDvYw5ZeFA43emsFjzCf0iYd2ava6q7z2LTVbX18XdyaDvX9UjNIMBevl2tkdo71VATyrG8ghd4LcV6qLb/oW/Ym//TwLc/ws=
\ No newline at end of file
diff --git a/_static/img/reinforcement_learning_diagram.jpg b/_static/img/reinforcement_learning_diagram.jpg
index bdcbc322502..7e04efc2534 100644
Binary files a/_static/img/reinforcement_learning_diagram.jpg and b/_static/img/reinforcement_learning_diagram.jpg differ
diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 7353bba05a4..611cfb32448 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -6,7 +6,7 @@
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v0 task from the `OpenAI Gym `__.
+on the CartPole-v1 task from the `OpenAI Gym `__.
**Task**
@@ -30,22 +30,19 @@
The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
-However, neural networks can solve the task purely by looking at the
-scene, so we'll use a patch of the screen centered on the cart as an
-input. Because of this, our results aren't directly comparable to the
-ones from the official leaderboard - our task is much harder.
-Unfortunately this does slow down the training, because we have to
-render all the frames.
+We take these 4 inputs without any scaling and pass them through a
+small fully-connected network with 2 outputs, one for each action.
+The network is trained to predict the expected value for each action,
+given the input state. The action with the highest expected value is
+then chosen.
-Strictly speaking, we will present the state as the difference between
-the current screen patch and the previous one. This will allow the agent
-to take the velocity of the pole into account from one image.
**Packages**
First, let's import needed packages. Firstly, we need
`gym `__ for the environment
+Install by using `pip`. If you are running this in Google colab, run:
.. code-block:: bash
@@ -57,8 +54,6 @@
- neural networks (``torch.nn``)
- optimization (``torch.optim``)
- automatic differentiation (``torch.autograd``)
-- utilities for vision tasks (``torchvision`` - `a separate
- package `__).
"""
@@ -70,19 +65,18 @@
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
-from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
-import torchvision.transforms as T
-
-if gym.__version__ < '0.26':
- env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
+if gym.__version__[:4] == '0.26':
+ env = gym.make('CartPole-v1')
+elif gym.__version__[:4] == '0.25':
+ env = gym.make('CartPole-v1', new_step_api=True)
else:
- env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped
+ raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
@@ -152,9 +146,11 @@ def __len__(self):
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
-# that ensures the sum converges. It makes rewards from the uncertain far
-# future less important for our agent than the ones in the near future
-# that it can be fairly confident about.
+# that ensures the sum converges. A lower :math:`\gamma` makes
+# rewards from the uncertain far future less important for our agent
+# than the ones in the near future that it can be fairly confident
+# about. It also encourages agents to collect reward closer in time
+# than equivalent rewards temporally future away.
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
@@ -177,7 +173,7 @@ def __len__(self):
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
-# .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a))
+# .. math:: \delta = Q(s, a) - (r + \gamma \max_a' Q(s', a))
#
# To minimise this error, we will use the `Huber
# loss `__. The Huber loss acts
@@ -211,86 +207,18 @@ def __len__(self):
class DQN(nn.Module):
- def __init__(self, h, w, outputs):
+ def __init__(self, n_observations, n_actions):
super(DQN, self).__init__()
- self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
- self.bn1 = nn.BatchNorm2d(16)
- self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
- self.bn2 = nn.BatchNorm2d(32)
- self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
- self.bn3 = nn.BatchNorm2d(32)
-
- # Number of Linear input connections depends on output of conv2d layers
- # and therefore the input image size, so compute it.
- def conv2d_size_out(size, kernel_size = 5, stride = 2):
- return (size - (kernel_size - 1) - 1) // stride + 1
- convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
- convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
- linear_input_size = convw * convh * 32
- self.head = nn.Linear(linear_input_size, outputs)
+ self.layer1 = nn.Linear(n_observations, 128)
+ self.layer2 = nn.Linear(128, 128)
+ self.layer3 = nn.Linear(128, n_actions)
# Called with either one element to determine next action, or a batch
# during optimization. Returns tensor([[left0exp,right0exp]...]).
def forward(self, x):
- x = x.to(device)
- x = F.relu(self.bn1(self.conv1(x)))
- x = F.relu(self.bn2(self.conv2(x)))
- x = F.relu(self.bn3(self.conv3(x)))
- return self.head(x.view(x.size(0), -1))
-
-
-######################################################################
-# Input extraction
-# ^^^^^^^^^^^^^^^^
-#
-# The code below are utilities for extracting and processing rendered
-# images from the environment. It uses the ``torchvision`` package, which
-# makes it easy to compose image transforms. Once you run the cell it will
-# display an example patch that it extracted.
-#
-
-resize = T.Compose([T.ToPILImage(),
- T.Resize(40, interpolation=Image.CUBIC),
- T.ToTensor()])
-
-
-def get_cart_location(screen_width):
- world_width = env.x_threshold * 2
- scale = screen_width / world_width
- return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART
-
-def get_screen():
- # Returned screen requested by gym is 400x600x3, but is sometimes larger
- # such as 800x1200x3. Transpose it into torch order (CHW).
- screen = env.render().transpose((2, 0, 1))
- # Cart is in the lower half, so strip off the top and bottom of the screen
- _, screen_height, screen_width = screen.shape
- screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
- view_width = int(screen_width * 0.6)
- cart_location = get_cart_location(screen_width)
- if cart_location < view_width // 2:
- slice_range = slice(view_width)
- elif cart_location > (screen_width - view_width // 2):
- slice_range = slice(-view_width, None)
- else:
- slice_range = slice(cart_location - view_width // 2,
- cart_location + view_width // 2)
- # Strip off the edges, so that we have a square image centered on a cart
- screen = screen[:, :, slice_range]
- # Convert to float, rescale, convert to torch tensor
- # (this doesn't require a copy)
- screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
- screen = torch.from_numpy(screen)
- # Resize, and add a batch dimension (BCHW)
- return resize(screen).unsqueeze(0)
-
-
-env.reset()
-plt.figure()
-plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
- interpolation='none')
-plt.title('Example extracted screen')
-plt.show()
+ x = F.relu(self.layer1(x))
+ x = F.relu(self.layer2(x))
+ return self.layer3(x)
######################################################################
@@ -315,28 +243,35 @@ def get_screen():
# episode.
#
+# BATCH_SIZE is the number of transitions sampled from the replay buffer
+# GAMMA is the discount factor as mentioned in the previous section
+# EPS_START is the starting value of epsilon
+# EPS_END is the final value of epsilon
+# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+# TAU is the update rate of the target network
+# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
-GAMMA = 0.999
+GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
-EPS_DECAY = 200
-TARGET_UPDATE = 10
-
-# Get screen size so that we can initialize layers correctly based on shape
-# returned from AI gym. Typical dimensions at this point are close to 3x40x90
-# which is the result of a clamped and down-scaled render buffer in get_screen()
-init_screen = get_screen()
-_, _, screen_height, screen_width = init_screen.shape
+EPS_DECAY = 1000
+TAU = 0.005
+LR = 1e-4
# Get number of actions from gym action space
n_actions = env.action_space.n
-
-policy_net = DQN(screen_height, screen_width, n_actions).to(device)
-target_net = DQN(screen_height, screen_width, n_actions).to(device)
+# Get the number of state observations
+if gym.__version__[:4] == '0.26':
+ state, _ = env.reset()
+elif gym.__version__[:4] == '0.25':
+ state, _ = env.reset(return_info=True)
+n_observations = len(state)
+
+policy_net = DQN(n_observations, n_actions).to(device)
+target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
-target_net.eval()
-optimizer = optim.RMSprop(policy_net.parameters())
+optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)
@@ -356,14 +291,14 @@ def select_action(state):
# found, so we pick action with the larger expected reward.
return policy_net(state).max(1)[1].view(1, 1)
else:
- return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
+ return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
episode_durations = []
def plot_durations():
- plt.figure(2)
+ plt.figure(1)
plt.clf()
durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.title('Training...')
@@ -394,10 +329,9 @@ def plot_durations():
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
-# added stability. The target network has its weights kept frozen most of
-# the time, but is updated with the policy network's weights every so often.
-# This is usually a set number of steps but we shall use episodes for
-# simplicity.
+# added stability. The target network is updated at every step with a
+# `soft update `__ controlled by
+# the hyperparameter ``TAU``, which was previously defined.
#
def optimize_model():
@@ -430,7 +364,8 @@ def optimize_model():
# This is merged based on the mask, such that we'll have either the expected
# state value or 0 in case the state was final.
next_state_values = torch.zeros(BATCH_SIZE, device=device)
- next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
+ with torch.no_grad():
+ next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
@@ -441,44 +376,49 @@ def optimize_model():
# Optimize the model
optimizer.zero_grad()
loss.backward()
- for param in policy_net.parameters():
- param.grad.data.clamp_(-1, 1)
+ # In-place gradient clipping
+ torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
optimizer.step()
######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
-# the environment and initialize the ``state`` Tensor. Then, we sample
-# an action, execute it, observe the next screen and the reward (always
+# the environment and obtain the initial ``state`` Tensor. Then, we sample
+# an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
-# Below, `num_episodes` is set small. You should download
-# the notebook and run lot more epsiodes, such as 300+ for meaningful
-# duration improvements.
+# Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50
+# episodes are scheduled so training does not take too long. However, 50
+# episodes is insufficient for to observe good performance on cartpole.
+# You should see the model constantly achieve 500 steps within 600 training
+# episodes. Training RL agents can be a noisy process, so restarting training
+# can produce better results if convergence is not observed.
#
-num_episodes = 50
+if torch.cuda.is_available():
+ num_episodes = 600
+else:
+ num_episodes = 50
+
for i_episode in range(num_episodes):
- # Initialize the environment and state
- env.reset()
- last_screen = get_screen()
- current_screen = get_screen()
- state = current_screen - last_screen
+ # Initialize the environment and get it's state
+ if gym.__version__[:4] == '0.26':
+ state, _ = env.reset()
+ elif gym.__version__[:4] == '0.25':
+ state, _ = env.reset(return_info=True)
+ state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
for t in count():
- # Select and perform an action
action = select_action(state)
- _, reward, done, _, _ = env.step(action.item())
+ observation, reward, terminated, truncated, _ = env.step(action.item())
reward = torch.tensor([reward], device=device)
+ done = terminated or truncated
- # Observe new state
- last_screen = current_screen
- current_screen = get_screen()
- if not done:
- next_state = current_screen - last_screen
- else:
+ if terminated:
next_state = None
+ else:
+ next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
# Store the transition in memory
memory.push(state, action, next_state, reward)
@@ -488,18 +428,21 @@ def optimize_model():
# Perform one step of the optimization (on the policy network)
optimize_model()
+
+ # Soft update of the target network's weights
+ # θ′ ← τ θ + (1 −τ )θ′
+ target_net_state_dict = target_net.state_dict()
+ policy_net_state_dict = policy_net.state_dict()
+ for key in policy_net_state_dict:
+ target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
+ target_net.load_state_dict(target_net_state_dict)
+
if done:
episode_durations.append(t + 1)
plot_durations()
break
- # Update the target network, copying all weights and biases in DQN
- if t % TARGET_UPDATE == 0:
- target_net.load_state_dict(policy_net.state_dict())
-
print('Complete')
-env.render()
-env.close()
plt.ioff()
plt.show()
@@ -512,6 +455,6 @@ def optimize_model():
# step sample from the gym environment. We record the results in the
# replay memory and also run optimization step on every iteration.
# Optimization picks a random batch from the replay memory to do training of the
-# new policy. "Older" target_net is also used in optimization to compute the
-# expected Q values; it is updated occasionally to keep it current.
+# new policy. The "older" target_net is also used in optimization to compute the
+# expected Q values. A soft update of its weights are performed at every step.
#