From d472a149e187e7dc46345fe07819bb62c2b417fc Mon Sep 17 00:00:00 2001 From: Kai Arulkumaran Date: Wed, 7 Feb 2018 23:35:37 -0500 Subject: [PATCH 1/2] Add target network to DQN Closes #181 Closes #194 --- .../reinforcement_q_learning.py | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 763f21492ea..dd7dc52b1eb 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -62,7 +62,6 @@ import matplotlib.pyplot as plt from collections import namedtuple from itertools import count -from copy import deepcopy from PIL import Image import torch @@ -187,7 +186,7 @@ def __len__(self): # # .. math:: # -# \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta) +# \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta) # # .. math:: # @@ -273,6 +272,7 @@ def get_screen(): # Resize, and add a batch dimension (BCHW) return resize(screen).unsqueeze(0).type(Tensor) + env.reset() plt.figure() plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), @@ -311,13 +311,20 @@ def get_screen(): EPS_START = 0.9 EPS_END = 0.05 EPS_DECAY = 200 +TARGET_UPDATE = 50 -model = DQN() +policy_net = DQN() +target_net = DQN() +target_net.eval() +target_net.load_state_dict(policy_net.state_dict()) +for param in target_net.parameters(): + param.requires_grad = False if use_cuda: - model.cuda() + policy_net.cuda() + target_net.cuda() -optimizer = optim.RMSprop(model.parameters()) +optimizer = optim.RMSprop(policy_net.parameters()) memory = ReplayMemory(10000) @@ -331,7 +338,7 @@ def select_action(state): math.exp(-1. * steps_done / EPS_DECAY) steps_done += 1 if sample > eps_threshold: - return model( + return policy_net( Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1) else: return LongTensor([[random.randrange(2)]]) @@ -371,14 +378,14 @@ def plot_durations(): # all the tensors into a single one, computes :math:`Q(s_t, a_t)` and # :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our # loss. By defition we set :math:`V(s) = 0` if :math:`s` is a terminal -# state. - - -last_sync = 0 - +# state. We also use a target network to compute :math:`V(s_{t+1})` for +# added stability. The target network has its weights kept frozen most of +# the time, but is updated with the policy network's weights every so often. +# This is usually a set number of steps but we shall use episodes for +# simplicity. +# def optimize_model(): - global last_sync if len(memory) < BATCH_SIZE: return transitions = memory.sample(BATCH_SIZE) @@ -389,28 +396,19 @@ def optimize_model(): # Compute a mask of non-final states and concatenate the batch elements non_final_mask = ByteTensor(tuple(map(lambda s: s is not None, batch.next_state))) - - # We don't want to backprop through the expected action values and volatile - # will save us on temporarily changing the model parameters' - # requires_grad to False! 
non_final_next_states = Variable(torch.cat([s for s in batch.next_state - if s is not None]), - volatile=True) + if s is not None])) state_batch = Variable(torch.cat(batch.state)) action_batch = Variable(torch.cat(batch.action)) reward_batch = Variable(torch.cat(batch.reward)) # Compute Q(s_t, a) - the model computes Q(s_t), then we select the # columns of actions taken - state_action_values = model(state_batch).gather(1, action_batch) + state_action_values = policy_net(state_batch).gather(1, action_batch) # Compute V(s_{t+1}) for all next states. next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor)) - next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0] - # Now, we don't want to mess up the loss with a volatile flag, so let's - # clear it. After this, we'll just end up with a Variable that has - # requires_grad=False - next_state_values.volatile = False + next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0] # Compute the expected Q values expected_state_action_values = (next_state_values * GAMMA) + reward_batch @@ -420,10 +418,11 @@ def optimize_model(): # Optimize the model optimizer.zero_grad() loss.backward() - for param in model.parameters(): + for param in policy_net.parameters(): param.grad.data.clamp_(-1, 1) optimizer.step() + ###################################################################### # # Below, you can find the main training loop. At the beginning we reset @@ -434,6 +433,7 @@ def optimize_model(): # # Below, `num_episodes` is set small. You should download # the notebook and run lot more epsiodes. +# num_episodes = 10 for i_episode in range(num_episodes): @@ -468,6 +468,9 @@ def optimize_model(): episode_durations.append(t + 1) plot_durations() break + # Update the target network + if i_episode % 5 == 0: + target_net.load_state_dict(policy_net.state_dict()) print('Complete') env.render(close=True) From a855aea47fb1e45c5b47e9cfa7debaa849d8f67b Mon Sep 17 00:00:00 2001 From: Kai Arulkumaran Date: Wed, 7 Feb 2018 23:48:22 -0500 Subject: [PATCH 2/2] Revert back to volatility for now --- intermediate_source/reinforcement_q_learning.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index dd7dc52b1eb..1046faaa6f2 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -236,7 +236,7 @@ def forward(self, x): # resize = T.Compose([T.ToPILImage(), - T.Scale(40, interpolation=Image.CUBIC), + T.Resize(40, interpolation=Image.CUBIC), T.ToTensor()]) # This is based on the code from gym. 
@@ -311,14 +311,12 @@ def get_screen(): EPS_START = 0.9 EPS_END = 0.05 EPS_DECAY = 200 -TARGET_UPDATE = 50 +TARGET_UPDATE = 10 policy_net = DQN() target_net = DQN() -target_net.eval() target_net.load_state_dict(policy_net.state_dict()) -for param in target_net.parameters(): - param.requires_grad = False +target_net.eval() if use_cuda: policy_net.cuda() @@ -397,7 +395,8 @@ def optimize_model(): non_final_mask = ByteTensor(tuple(map(lambda s: s is not None, batch.next_state))) non_final_next_states = Variable(torch.cat([s for s in batch.next_state - if s is not None])) + if s is not None]), + volatile=True) state_batch = Variable(torch.cat(batch.state)) action_batch = Variable(torch.cat(batch.action)) reward_batch = Variable(torch.cat(batch.reward)) @@ -411,6 +410,8 @@ def optimize_model(): next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0] # Compute the expected Q values expected_state_action_values = (next_state_values * GAMMA) + reward_batch + # Undo volatility (which was used to prevent unnecessary gradients) + expected_state_action_values = Variable(expected_state_action_values.data) # Compute Huber loss loss = F.smooth_l1_loss(state_action_values, expected_state_action_values) @@ -435,7 +436,7 @@ def optimize_model(): # the notebook and run lot more epsiodes. # -num_episodes = 10 +num_episodes = 50 for i_episode in range(num_episodes): # Initialize the environment and state env.reset() @@ -469,7 +470,7 @@ def optimize_model(): plot_durations() break # Update the target network - if i_episode % 5 == 0: + if i_episode % TARGET_UPDATE == 0: target_net.load_state_dict(policy_net.state_dict()) print('Complete')
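
Note for readers on PyTorch 0.4 or later, where `volatile` has no effect and plain tensors replace `Variable`: the gradient-blocking that the second commit reverts back to can instead be expressed with `torch.no_grad()`. The sketch below is illustrative only and is not part of either patch; it assumes `policy_net`, `target_net`, `optimizer`, `memory`, `Transition`, `BATCH_SIZE` and `GAMMA` are defined as in the tutorial above, and it needs PyTorch 1.2+ for the boolean mask.

    import torch
    import torch.nn.functional as F

    def optimize_model():
        # Same hunk as above, restated without Variable/volatile: the target
        # network's values are computed under torch.no_grad() so no gradients
        # flow into the bootstrapped term.
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        device = state_batch.device

        # Mask of non-final states and the corresponding next-state batch
        non_final_mask = torch.tensor(
            tuple(map(lambda s: s is not None, batch.next_state)),
            device=device, dtype=torch.bool)
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None])

        # Q(s_t, a): the policy network scores every action, then we pick out
        # the columns for the actions that were actually taken.
        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # V(s_{t+1}): taken from the frozen target network; torch.no_grad()
        # plays the role that volatile=True plays in the patch.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        with torch.no_grad():
            next_state_values[non_final_mask] = \
                target_net(non_final_next_states).max(1)[0]
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        # Huber loss between predicted and expected Q values
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        for param in policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

The periodic hard update is unchanged in this formulation: every TARGET_UPDATE episodes the training loop copies the policy network's weights into the target network with target_net.load_state_dict(policy_net.state_dict()), exactly as in the loop above, and target_net can simply stay in eval() mode with no requires_grad or volatile handling.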