
Commit 7ef2a5a

Kaixhin authored and chsasank committed

Add target network to DQN (#209)

* Add target network to DQN
  Closes #181
  Closes #194
* Revert back to volatility for now
1 parent ab62128 commit 7ef2a5a

File tree

1 file changed: +29 -25 lines changed


intermediate_source/reinforcement_q_learning.py

Lines changed: 29 additions & 25 deletions
@@ -62,7 +62,6 @@
 import matplotlib.pyplot as plt
 from collections import namedtuple
 from itertools import count
-from copy import deepcopy
 from PIL import Image

 import torch
@@ -187,7 +186,7 @@ def __len__(self):
 #
 # .. math::
 #
-#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
+#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
 #
 # .. math::
 #
@@ -237,7 +236,7 @@ def forward(self, x):
 #

 resize = T.Compose([T.ToPILImage(),
-                    T.Scale(40, interpolation=Image.CUBIC),
+                    T.Resize(40, interpolation=Image.CUBIC),
                     T.ToTensor()])

 # This is based on the code from gym.
@@ -273,6 +272,7 @@ def get_screen():
     # Resize, and add a batch dimension (BCHW)
     return resize(screen).unsqueeze(0).type(Tensor)

+
 env.reset()
 plt.figure()
 plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
@@ -311,13 +311,18 @@ def get_screen():
 EPS_START = 0.9
 EPS_END = 0.05
 EPS_DECAY = 200
+TARGET_UPDATE = 10

-model = DQN()
+policy_net = DQN()
+target_net = DQN()
+target_net.load_state_dict(policy_net.state_dict())
+target_net.eval()

 if use_cuda:
-    model.cuda()
+    policy_net.cuda()
+    target_net.cuda()

-optimizer = optim.RMSprop(model.parameters())
+optimizer = optim.RMSprop(policy_net.parameters())
 memory = ReplayMemory(10000)

@@ -331,7 +336,7 @@ def select_action(state):
         math.exp(-1. * steps_done / EPS_DECAY)
     steps_done += 1
     if sample > eps_threshold:
-        return model(
+        return policy_net(
             Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
     else:
         return LongTensor([[random.randrange(2)]])
@@ -371,14 +376,14 @@ def plot_durations():
 # all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
 # :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
 # loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
-# state.
-
-
-last_sync = 0
-
+# state. We also use a target network to compute :math:`V(s_{t+1})` for
+# added stability. The target network has its weights kept frozen most of
+# the time, but is updated with the policy network's weights every so often.
+# This is usually a set number of steps but we shall use episodes for
+# simplicity.
+#

 def optimize_model():
-    global last_sync
     if len(memory) < BATCH_SIZE:
         return
     transitions = memory.sample(BATCH_SIZE)
@@ -389,10 +394,6 @@ def optimize_model():
     # Compute a mask of non-final states and concatenate the batch elements
     non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                           batch.next_state)))
-
-    # We don't want to backprop through the expected action values and volatile
-    # will save us on temporarily changing the model parameters'
-    # requires_grad to False!
     non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                 if s is not None]),
                                      volatile=True)
@@ -402,28 +403,27 @@ def optimize_model():

     # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
     # columns of actions taken
-    state_action_values = model(state_batch).gather(1, action_batch)
+    state_action_values = policy_net(state_batch).gather(1, action_batch)

     # Compute V(s_{t+1}) for all next states.
     next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
-    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
-    # Now, we don't want to mess up the loss with a volatile flag, so let's
-    # clear it. After this, we'll just end up with a Variable that has
-    # requires_grad=False
-    next_state_values.volatile = False
+    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
     # Compute the expected Q values
     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
+    # Undo volatility (which was used to prevent unnecessary gradients)
+    expected_state_action_values = Variable(expected_state_action_values.data)

     # Compute Huber loss
     loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)

     # Optimize the model
     optimizer.zero_grad()
     loss.backward()
-    for param in model.parameters():
+    for param in policy_net.parameters():
         param.grad.data.clamp_(-1, 1)
     optimizer.step()

+
 ######################################################################
 #
 # Below, you can find the main training loop. At the beginning we reset
@@ -434,8 +434,9 @@ def optimize_model():
 #
 # Below, `num_episodes` is set small. You should download
 # the notebook and run a lot more episodes.
+#

-num_episodes = 10
+num_episodes = 50
 for i_episode in range(num_episodes):
     # Initialize the environment and state
     env.reset()
@@ -468,6 +469,9 @@ def optimize_model():
             episode_durations.append(t + 1)
             plot_durations()
             break
+    # Update the target network
+    if i_episode % TARGET_UPDATE == 0:
+        target_net.load_state_dict(policy_net.state_dict())

 print('Complete')
 env.render(close=True)
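
For reference, the sketch below condenses the pattern this commit introduces (a separate, periodically synced target network for the DQN bootstrap target) into a small, self-contained example. It is not the tutorial code itself: TinyDQN, optimize_step, and the randomly generated batch are hypothetical stand-ins, and torch.no_grad() replaces the deprecated volatile flag that the commit deliberately keeps ("Revert back to volatility for now"). The hyper-parameter names (BATCH_SIZE, GAMMA, TARGET_UPDATE) mirror the tutorial's.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hyper-parameters, following the tutorial's names.
BATCH_SIZE = 128
GAMMA = 0.999
TARGET_UPDATE = 10           # sync target_net with policy_net every N episodes
N_STATES, N_ACTIONS = 4, 2   # toy sizes, not the tutorial's image input


class TinyDQN(nn.Module):
    """Stand-in for the tutorial's convolutional DQN; any Q-network works."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(N_STATES, 32), nn.ReLU(),
                                 nn.Linear(32, N_ACTIONS))

    def forward(self, x):
        return self.net(x)


policy_net = TinyDQN()
target_net = TinyDQN()
target_net.load_state_dict(policy_net.state_dict())  # start from identical weights
target_net.eval()                                     # the target net is never trained directly

optimizer = optim.RMSprop(policy_net.parameters())


def optimize_step(state, action, reward, next_state, non_final_mask):
    # Q(s_t, a_t) from the policy network, for the actions that were taken.
    state_action_values = policy_net(state).gather(1, action)

    # V(s_{t+1}) from the *target* network; no_grad plays the role of the old
    # volatile=True trick, so no gradients flow through the bootstrap target.
    next_state_values = torch.zeros(BATCH_SIZE)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(next_state).max(1)[0]
    expected = reward + GAMMA * next_state_values

    # Huber loss and the same gradient clipping as the tutorial.
    loss = F.smooth_l1_loss(state_action_values.squeeze(1), expected)
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss.item()


# Fake batch just to exercise the update; a real run samples from ReplayMemory.
state = torch.randn(BATCH_SIZE, N_STATES)
action = torch.randint(N_ACTIONS, (BATCH_SIZE, 1))
reward = torch.randn(BATCH_SIZE)
non_final_mask = torch.rand(BATCH_SIZE) > 0.1
next_state = torch.randn(int(non_final_mask.sum()), N_STATES)

for i_episode in range(50):
    optimize_step(state, action, reward, next_state, non_final_mask)
    # The key change in this commit: periodically copy the policy weights
    # into the otherwise frozen target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

Keeping the target network frozen between syncs means the TD target no longer shifts at every optimisation step, which is the stability argument behind the change.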
