
Add target network to DQN #209

Merged (2 commits) on Mar 29, 2018
54 changes: 29 additions & 25 deletions intermediate_source/reinforcement_q_learning.py
@@ -62,7 +62,6 @@
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
- from copy import deepcopy
from PIL import Image

import torch
@@ -187,7 +186,7 @@ def __len__(self):
#
# .. math::
#
# \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
# \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
#
# .. math::
#
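As a quick reference for the batch-averaged Huber loss written above, here is a minimal, self-contained sketch of how it maps onto PyTorch's F.smooth_l1_loss. The Q-value predictions and TD targets below are made up for illustration; none of these numbers come from the tutorial:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# Hypothetical Q(s_t, a_t) predictions and TD targets for a batch of 4
q_values = Variable(torch.FloatTensor([1.0, 2.5, -0.5, 3.0]))
targets = Variable(torch.FloatTensor([1.2, 0.5, -0.4, 3.0]))

# smooth_l1_loss is the Huber loss: quadratic for small errors,
# linear for large ones, averaged over the batch
loss = F.smooth_l1_loss(q_values, targets)
print(loss.data[0])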
@@ -237,7 +236,7 @@ def forward(self, x):
#

resize = T.Compose([T.ToPILImage(),
- T.Scale(40, interpolation=Image.CUBIC),
+ T.Resize(40, interpolation=Image.CUBIC),
T.ToTensor()])
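Side note on the change above: torchvision renamed Scale to Resize (keeping Scale as a deprecated alias for a while), which is what this line updates. A small sanity check of the preprocessing pipeline on a dummy screen; the shape here is made up, the real screen comes from env.render():

import numpy as np
import torchvision.transforms as T
from PIL import Image

resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

dummy_screen = np.zeros((160, 360, 3), dtype=np.uint8)  # H x W x C
out = resize(dummy_screen)
print(out.size())  # smaller edge scaled to 40: torch.Size([3, 40, 90]) for this dummy input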

# This is based on the code from gym.
@@ -273,6 +272,7 @@ def get_screen():
# Resize, and add a batch dimension (BCHW)
return resize(screen).unsqueeze(0).type(Tensor)


env.reset()
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
@@ -311,13 +311,18 @@ def get_screen():
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
+ TARGET_UPDATE = 10

- model = DQN()
+ policy_net = DQN()
+ target_net = DQN()
+ target_net.load_state_dict(policy_net.state_dict())
+ target_net.eval()

if use_cuda:
- model.cuda()
+ policy_net.cuda()
+ target_net.cuda()

- optimizer = optim.RMSprop(model.parameters())
+ optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)
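A note on the setup above: the target network starts as an exact copy of the policy network (via load_state_dict) and, further down, is refreshed with a hard copy every TARGET_UPDATE episodes. A common alternative, not used in this PR, is a soft (Polyak) update applied after every optimization step. A minimal sketch, with a hypothetical blending factor TAU:

TAU = 0.005  # hypothetical blending factor, not part of this tutorial

def soft_update(target_net, policy_net, tau=TAU):
    # Blend each policy parameter into the matching target parameter in place.
    # (BatchNorm running statistics are left out of this sketch for simplicity.)
    for target_param, policy_param in zip(target_net.parameters(),
                                          policy_net.parameters()):
        target_param.data.copy_(tau * policy_param.data +
                                (1 - tau) * target_param.data)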


@@ -331,7 +336,7 @@ def select_action(state):
math.exp(-1. * steps_done / EPS_DECAY)
steps_done += 1
if sample > eps_threshold:
- return model(
+ return policy_net(
Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
else:
return LongTensor([[random.randrange(2)]])
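For intuition about the exploration schedule used in select_action, a short sketch that just evaluates eps_threshold at a few illustrative values of steps_done, reusing the constants defined above:

import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200

for steps_done in (0, 200, 600, 2000):
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    print(steps_done, round(eps_threshold, 3))

# Roughly 0.9 at step 0, ~0.363 after 200 steps, ~0.092 after 600 steps,
# approaching EPS_END = 0.05 after a few thousand steps.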
@@ -371,14 +376,14 @@ def plot_durations():
# all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
- # state.
-
-
- last_sync = 0
-
+ # state. We also use a target network to compute :math:`V(s_{t+1})` for
+ # added stability. The target network has its weights kept frozen most of
+ # the time, but is updated with the policy network's weights every so often.
+ # This is usually a set number of steps but we shall use episodes for
+ # simplicity.
+ #
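To make the role of the target network concrete before the implementation below, here is a toy, self-contained sketch of how the expected Q-values are formed. The rewards, Q-values and mask are made up, and the real code uses a ByteTensor mask for indexing rather than multiplication:

import torch

GAMMA = 0.999  # discount factor; stand-in for the value defined with the hyperparameters above

# Dummy batch of 3 transitions: the third one ended the episode
reward_batch = torch.FloatTensor([1.0, 1.0, 1.0])
max_next_q = torch.FloatTensor([4.2, 3.7, 0.0])   # stand-in for target_net(s').max(1)[0]
non_final = torch.FloatTensor([1.0, 1.0, 0.0])    # 0 marks a terminal next state

# V(s_{t+1}) is max_a Q_target(s_{t+1}, a) for non-final states, 0 otherwise
next_state_values = max_next_q * non_final

# TD target r + gamma * V(s_{t+1}); the Huber loss compares the policy
# network's Q(s_t, a_t) against this, with no gradient flowing into target_net
expected_state_action_values = reward_batch + GAMMA * next_state_values
print(expected_state_action_values)  # ~[5.196, 4.696, 1.000]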

def optimize_model():
- global last_sync
if len(memory) < BATCH_SIZE:
return
transitions = memory.sample(BATCH_SIZE)
@@ -389,10 +394,6 @@ def optimize_model():
# Compute a mask of non-final states and concatenate the batch elements
non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
batch.next_state)))

- # We don't want to backprop through the expected action values and volatile
- # will save us on temporarily changing the model parameters'
- # requires_grad to False!
non_final_next_states = Variable(torch.cat([s for s in batch.next_state
if s is not None]),
volatile=True)
@@ -402,28 +403,27 @@ def optimize_model():

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken
- state_action_values = model(state_batch).gather(1, action_batch)
+ state_action_values = policy_net(state_batch).gather(1, action_batch)

# Compute V(s_{t+1}) for all next states.
next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
- next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
- # Now, we don't want to mess up the loss with a volatile flag, so let's
- # clear it. After this, we'll just end up with a Variable that has
- # requires_grad=False
- next_state_values.volatile = False
+ next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
+ # Undo volatility (which was used to prevent unnecessary gradients)
+ expected_state_action_values = Variable(expected_state_action_values.data)

# Compute Huber loss
loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)

# Optimize the model
optimizer.zero_grad()
loss.backward()
- for param in model.parameters():
+ for param in policy_net.parameters():
param.grad.data.clamp_(-1, 1)
optimizer.step()
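A tiny illustration, on a throwaway parameter unrelated to the DQN, of what the in-place clamp_(-1, 1) above does to gradients before the optimizer step:

import torch
from torch.autograd import Variable

w = Variable(torch.zeros(3), requires_grad=True)
loss = (w * Variable(torch.FloatTensor([5.0, -0.3, 2.0]))).sum()
loss.backward()

print(w.grad.data)         # [5.0, -0.3, 2.0]
w.grad.data.clamp_(-1, 1)
print(w.grad.data)         # [1.0, -0.3, 1.0]: each component bounded to [-1, 1]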


######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
@@ -434,8 +434,9 @@ def optimize_model():
#
# Below, `num_episodes` is set small. You should download
# the notebook and run a lot more episodes.
#

- num_episodes = 10
+ num_episodes = 50
for i_episode in range(num_episodes):
# Initialize the environment and state
env.reset()
@@ -468,6 +469,9 @@ def optimize_model():
episode_durations.append(t + 1)
plot_durations()
break
+ # Update the target network
+ if i_episode % TARGET_UPDATE == 0:
+     target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render(close=True)