diff --git a/_static/img/reinforcement_learning_diagram.drawio b/_static/img/reinforcement_learning_diagram.drawio new file mode 100644 index 00000000000..2ff4e6f0270 --- /dev/null +++ b/_static/img/reinforcement_learning_diagram.drawio @@ -0,0 +1 @@ +5Vpbc+MmFP41nmkfmpGEpMiPjTftzrTZZtbbbbYvHSxhiRQJFeHb/vqChG4gx95ElqfTeCaGwwEO37lwDskMLNL9zwzmyQONEJk5VrSfgXczx7GtuSW+JOVQUXzXrggxw5FiaglL/BXVMxV1gyNU9Bg5pYTjvE8MaZahkPdokDG667OtKenvmsMYGYRlCIlJ/QNHPKmogWe19PcIx0m9s22pkRTWzIpQJDCiuw4J3M/AglHKq1a6XyAiwatxqeb9dGS0EYyhjJ8z4flQfCient3PwfKXVfTn40P6/vMPapUtJBt14I8oJ/AgaA8opeygZOeHGhBGN1mE5JrWDNztEszRMoehHN0JExC0hKdE9GzRXGNCFpRQVs4Fa09+BD1mMMJC7s7YqvyIsYIz+jfqjPjljxhRsiLG0f4oCHYDrbBJRFPE5SEsNcGt1aPMMVDdXatbu1ZY0tGrr2hQmVPcrNwiLhoK9G9QgGMo4Lec41T6gWN9535v4C/WFLaPTmM/AlrebR8t2z0TruBScAEDLgaziKaCtoI8TAy4wg3bltYqIUFZ9KOMCaIbElgUOOxD1rftai0UGQFCg0/sRzcsRKfdjEMWI37KGkx1dOD2BtCuaQwRyPG2L+6QCtQOjxSLgxz1DRdoWqyOqWZ1I42+kNdfCOjmUOFgLCR0U0aemi2XDMVxgQ3ztK0X5fJtjR/0+EWjkqA1z0YHr7dYz7DYR0pwKM/5AfFRw2sEUbAOh8PrLYI+sgbDaxig1foy4dWxrh1fAzOACodeqi5lPKExzSC5b6laGGh5fqU0V8g/I84PKluBG06HQu8okcN/W+Q4OyS8CWTfsPFPlbz/Cxu/eg5hm0nEmPcg2mP+JNs3nup96Yy823c7h/HvTu8/cXfqV9H8lXen5xxJuUa+O91A2yd4+e709LvTP8Hvvsh/mbvWNtPDUR0hE+I9NVNF50vrFrLb+kLZq52hcaCO+9hTuY9zpvt413Qf3Vqc29e6z1xbyJkm9TSKb51fS4mdKVJP2zXc4fc8grwsLb3rlpaODog3cItaU96ijUdNmirWscHuXq03jjdVeLDnZ8aHI+qcJsGspewFdp8Iee8ivJU7Ehxn5YD/z0a+qN0RtOZtT7Ri9Q1Tac3ZqsjLvvWJQZzhLBbNUmtqXSFnuXQzb5zd7Bvxa5FQWkgvbB4vJDxCgXVlCEOOaXZhURwpylJQiRQFZdsL7wfkfh9RSFkkGql6XrQ2KiRddG9X7t2+rF10L6/ElUpu5VZ/ZWUt1D/piuk76/K8pWyq5S+lHiVi23oGaA9E7PlAxG4Yxw/ZZr4X1q5Vu9AE6V8wP5UAyt4jYlgcG7HrlUhVGL1WkgeO5EDf/r5oDdcuo9dIeqUPXk7ygK/xn3iPNACxJkgKHTMpHNVJBmod6+Z2snzmqmWMrlCgVx/nWjjQLc+7jIUDvYw5ZeFA43emsFjzCf0iYd2ava6q7z2LTVbX18XdyaDvX9UjNIMBevl2tkdo71VATyrG8ghd4LcV6qLb/oW/Ym//TwLc/ws= \ No newline at end of file diff --git a/_static/img/reinforcement_learning_diagram.jpg b/_static/img/reinforcement_learning_diagram.jpg index bdcbc322502..7e04efc2534 100644 Binary files a/_static/img/reinforcement_learning_diagram.jpg and b/_static/img/reinforcement_learning_diagram.jpg differ diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 7353bba05a4..611cfb32448 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -6,7 +6,7 @@ This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent -on the CartPole-v0 task from the `OpenAI Gym `__. +on the CartPole-v1 task from the `OpenAI Gym `__. **Task** @@ -30,22 +30,19 @@ The CartPole task is designed so that the inputs to the agent are 4 real values representing the environment state (position, velocity, etc.). -However, neural networks can solve the task purely by looking at the -scene, so we'll use a patch of the screen centered on the cart as an -input. Because of this, our results aren't directly comparable to the -ones from the official leaderboard - our task is much harder. -Unfortunately this does slow down the training, because we have to -render all the frames. +We take these 4 inputs without any scaling and pass them through a +small fully-connected network with 2 outputs, one for each action. +The network is trained to predict the expected value for each action, +given the input state. The action with the highest expected value is +then chosen. -Strictly speaking, we will present the state as the difference between -the current screen patch and the previous one. This will allow the agent -to take the velocity of the pole into account from one image. **Packages** First, let's import needed packages. Firstly, we need `gym `__ for the environment +Install by using `pip`. If you are running this in Google colab, run: .. code-block:: bash @@ -57,8 +54,6 @@ - neural networks (``torch.nn``) - optimization (``torch.optim``) - automatic differentiation (``torch.autograd``) -- utilities for vision tasks (``torchvision`` - `a separate - package `__). """ @@ -70,19 +65,18 @@ import matplotlib.pyplot as plt from collections import namedtuple, deque from itertools import count -from PIL import Image import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F -import torchvision.transforms as T - -if gym.__version__ < '0.26': - env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped +if gym.__version__[:4] == '0.26': + env = gym.make('CartPole-v1') +elif gym.__version__[:4] == '0.25': + env = gym.make('CartPole-v1', new_step_api=True) else: - env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped + raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}") # set up matplotlib is_ipython = 'inline' in matplotlib.get_backend() @@ -152,9 +146,11 @@ def __len__(self): # :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where # :math:`R_{t_0}` is also known as the *return*. The discount, # :math:`\gamma`, should be a constant between :math:`0` and :math:`1` -# that ensures the sum converges. It makes rewards from the uncertain far -# future less important for our agent than the ones in the near future -# that it can be fairly confident about. +# that ensures the sum converges. A lower :math:`\gamma` makes +# rewards from the uncertain far future less important for our agent +# than the ones in the near future that it can be fairly confident +# about. It also encourages agents to collect reward closer in time +# than equivalent rewards temporally future away. # # The main idea behind Q-learning is that if we had a function # :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell @@ -177,7 +173,7 @@ def __len__(self): # The difference between the two sides of the equality is known as the # temporal difference error, :math:`\delta`: # -# .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a)) +# .. math:: \delta = Q(s, a) - (r + \gamma \max_a' Q(s', a)) # # To minimise this error, we will use the `Huber # loss `__. The Huber loss acts @@ -211,86 +207,18 @@ def __len__(self): class DQN(nn.Module): - def __init__(self, h, w, outputs): + def __init__(self, n_observations, n_actions): super(DQN, self).__init__() - self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2) - self.bn1 = nn.BatchNorm2d(16) - self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2) - self.bn2 = nn.BatchNorm2d(32) - self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2) - self.bn3 = nn.BatchNorm2d(32) - - # Number of Linear input connections depends on output of conv2d layers - # and therefore the input image size, so compute it. - def conv2d_size_out(size, kernel_size = 5, stride = 2): - return (size - (kernel_size - 1) - 1) // stride + 1 - convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w))) - convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h))) - linear_input_size = convw * convh * 32 - self.head = nn.Linear(linear_input_size, outputs) + self.layer1 = nn.Linear(n_observations, 128) + self.layer2 = nn.Linear(128, 128) + self.layer3 = nn.Linear(128, n_actions) # Called with either one element to determine next action, or a batch # during optimization. Returns tensor([[left0exp,right0exp]...]). def forward(self, x): - x = x.to(device) - x = F.relu(self.bn1(self.conv1(x))) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn3(self.conv3(x))) - return self.head(x.view(x.size(0), -1)) - - -###################################################################### -# Input extraction -# ^^^^^^^^^^^^^^^^ -# -# The code below are utilities for extracting and processing rendered -# images from the environment. It uses the ``torchvision`` package, which -# makes it easy to compose image transforms. Once you run the cell it will -# display an example patch that it extracted. -# - -resize = T.Compose([T.ToPILImage(), - T.Resize(40, interpolation=Image.CUBIC), - T.ToTensor()]) - - -def get_cart_location(screen_width): - world_width = env.x_threshold * 2 - scale = screen_width / world_width - return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART - -def get_screen(): - # Returned screen requested by gym is 400x600x3, but is sometimes larger - # such as 800x1200x3. Transpose it into torch order (CHW). - screen = env.render().transpose((2, 0, 1)) - # Cart is in the lower half, so strip off the top and bottom of the screen - _, screen_height, screen_width = screen.shape - screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)] - view_width = int(screen_width * 0.6) - cart_location = get_cart_location(screen_width) - if cart_location < view_width // 2: - slice_range = slice(view_width) - elif cart_location > (screen_width - view_width // 2): - slice_range = slice(-view_width, None) - else: - slice_range = slice(cart_location - view_width // 2, - cart_location + view_width // 2) - # Strip off the edges, so that we have a square image centered on a cart - screen = screen[:, :, slice_range] - # Convert to float, rescale, convert to torch tensor - # (this doesn't require a copy) - screen = np.ascontiguousarray(screen, dtype=np.float32) / 255 - screen = torch.from_numpy(screen) - # Resize, and add a batch dimension (BCHW) - return resize(screen).unsqueeze(0) - - -env.reset() -plt.figure() -plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), - interpolation='none') -plt.title('Example extracted screen') -plt.show() + x = F.relu(self.layer1(x)) + x = F.relu(self.layer2(x)) + return self.layer3(x) ###################################################################### @@ -315,28 +243,35 @@ def get_screen(): # episode. # +# BATCH_SIZE is the number of transitions sampled from the replay buffer +# GAMMA is the discount factor as mentioned in the previous section +# EPS_START is the starting value of epsilon +# EPS_END is the final value of epsilon +# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay +# TAU is the update rate of the target network +# LR is the learning rate of the AdamW optimizer BATCH_SIZE = 128 -GAMMA = 0.999 +GAMMA = 0.99 EPS_START = 0.9 EPS_END = 0.05 -EPS_DECAY = 200 -TARGET_UPDATE = 10 - -# Get screen size so that we can initialize layers correctly based on shape -# returned from AI gym. Typical dimensions at this point are close to 3x40x90 -# which is the result of a clamped and down-scaled render buffer in get_screen() -init_screen = get_screen() -_, _, screen_height, screen_width = init_screen.shape +EPS_DECAY = 1000 +TAU = 0.005 +LR = 1e-4 # Get number of actions from gym action space n_actions = env.action_space.n - -policy_net = DQN(screen_height, screen_width, n_actions).to(device) -target_net = DQN(screen_height, screen_width, n_actions).to(device) +# Get the number of state observations +if gym.__version__[:4] == '0.26': + state, _ = env.reset() +elif gym.__version__[:4] == '0.25': + state, _ = env.reset(return_info=True) +n_observations = len(state) + +policy_net = DQN(n_observations, n_actions).to(device) +target_net = DQN(n_observations, n_actions).to(device) target_net.load_state_dict(policy_net.state_dict()) -target_net.eval() -optimizer = optim.RMSprop(policy_net.parameters()) +optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True) memory = ReplayMemory(10000) @@ -356,14 +291,14 @@ def select_action(state): # found, so we pick action with the larger expected reward. return policy_net(state).max(1)[1].view(1, 1) else: - return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long) + return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long) episode_durations = [] def plot_durations(): - plt.figure(2) + plt.figure(1) plt.clf() durations_t = torch.tensor(episode_durations, dtype=torch.float) plt.title('Training...') @@ -394,10 +329,9 @@ def plot_durations(): # :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our # loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal # state. We also use a target network to compute :math:`V(s_{t+1})` for -# added stability. The target network has its weights kept frozen most of -# the time, but is updated with the policy network's weights every so often. -# This is usually a set number of steps but we shall use episodes for -# simplicity. +# added stability. The target network is updated at every step with a +# `soft update `__ controlled by +# the hyperparameter ``TAU``, which was previously defined. # def optimize_model(): @@ -430,7 +364,8 @@ def optimize_model(): # This is merged based on the mask, such that we'll have either the expected # state value or 0 in case the state was final. next_state_values = torch.zeros(BATCH_SIZE, device=device) - next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach() + with torch.no_grad(): + next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0] # Compute the expected Q values expected_state_action_values = (next_state_values * GAMMA) + reward_batch @@ -441,44 +376,49 @@ def optimize_model(): # Optimize the model optimizer.zero_grad() loss.backward() - for param in policy_net.parameters(): - param.grad.data.clamp_(-1, 1) + # In-place gradient clipping + torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100) optimizer.step() ###################################################################### # # Below, you can find the main training loop. At the beginning we reset -# the environment and initialize the ``state`` Tensor. Then, we sample -# an action, execute it, observe the next screen and the reward (always +# the environment and obtain the initial ``state`` Tensor. Then, we sample +# an action, execute it, observe the next state and the reward (always # 1), and optimize our model once. When the episode ends (our model # fails), we restart the loop. # -# Below, `num_episodes` is set small. You should download -# the notebook and run lot more epsiodes, such as 300+ for meaningful -# duration improvements. +# Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50 +# episodes are scheduled so training does not take too long. However, 50 +# episodes is insufficient for to observe good performance on cartpole. +# You should see the model constantly achieve 500 steps within 600 training +# episodes. Training RL agents can be a noisy process, so restarting training +# can produce better results if convergence is not observed. # -num_episodes = 50 +if torch.cuda.is_available(): + num_episodes = 600 +else: + num_episodes = 50 + for i_episode in range(num_episodes): - # Initialize the environment and state - env.reset() - last_screen = get_screen() - current_screen = get_screen() - state = current_screen - last_screen + # Initialize the environment and get it's state + if gym.__version__[:4] == '0.26': + state, _ = env.reset() + elif gym.__version__[:4] == '0.25': + state, _ = env.reset(return_info=True) + state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) for t in count(): - # Select and perform an action action = select_action(state) - _, reward, done, _, _ = env.step(action.item()) + observation, reward, terminated, truncated, _ = env.step(action.item()) reward = torch.tensor([reward], device=device) + done = terminated or truncated - # Observe new state - last_screen = current_screen - current_screen = get_screen() - if not done: - next_state = current_screen - last_screen - else: + if terminated: next_state = None + else: + next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0) # Store the transition in memory memory.push(state, action, next_state, reward) @@ -488,18 +428,21 @@ def optimize_model(): # Perform one step of the optimization (on the policy network) optimize_model() + + # Soft update of the target network's weights + # θ′ ← τ θ + (1 −τ )θ′ + target_net_state_dict = target_net.state_dict() + policy_net_state_dict = policy_net.state_dict() + for key in policy_net_state_dict: + target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU) + target_net.load_state_dict(target_net_state_dict) + if done: episode_durations.append(t + 1) plot_durations() break - # Update the target network, copying all weights and biases in DQN - if t % TARGET_UPDATE == 0: - target_net.load_state_dict(policy_net.state_dict()) - print('Complete') -env.render() -env.close() plt.ioff() plt.show() @@ -512,6 +455,6 @@ def optimize_model(): # step sample from the gym environment. We record the results in the # replay memory and also run optimization step on every iteration. # Optimization picks a random batch from the replay memory to do training of the -# new policy. "Older" target_net is also used in optimization to compute the -# expected Q values; it is updated occasionally to keep it current. +# new policy. The "older" target_net is also used in optimization to compute the +# expected Q values. A soft update of its weights are performed at every step. #