Skip to content

Commit 05ffab1

Browse files
authored
Merge pull request #88 from chsasank/rl_cuda
Use cuda on RL tutorial
2 parents 78ff247 + 7ae5a2c commit 05ffab1

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -68,17 +68,28 @@
6868
import torch
6969
import torch.nn as nn
7070
import torch.optim as optim
71-
import torch.autograd as autograd
7271
import torch.nn.functional as F
72+
from torch.autograd import Variable
7373
import torchvision.transforms as T
7474

75+
7576
env = gym.make('CartPole-v0').unwrapped
7677

78+
# set up matplotlib
7779
is_ipython = 'inline' in matplotlib.get_backend()
7880
if is_ipython:
7981
from IPython import display
8082

8183
plt.ion()
84+
85+
# if gpu is to be used
86+
use_cuda = torch.cuda.is_available()
87+
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
88+
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
89+
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
90+
Tensor = FloatTensor
91+
92+
8293
######################################################################
8394
# Replay Memory
8495
# -------------
@@ -260,12 +271,12 @@ def get_screen():
260271
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
261272
screen = torch.from_numpy(screen)
262273
# Resize, and add a batch dimension (BCHW)
263-
return resize(screen).unsqueeze(0)
274+
return resize(screen).unsqueeze(0).type(Tensor)
264275

265276
env.reset()
266277
plt.figure()
267-
plt.imshow(get_screen().squeeze(0).permute(
268-
1, 2, 0).numpy(), interpolation='none')
278+
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
279+
interpolation='none')
269280
plt.title('Example extracted screen')
270281
plt.show()
271282

@@ -300,22 +311,14 @@ def get_screen():
300311
EPS_START = 0.9
301312
EPS_END = 0.05
302313
EPS_DECAY = 200
303-
USE_CUDA = torch.cuda.is_available()
304314

305315
model = DQN()
306-
memory = ReplayMemory(10000)
307-
optimizer = optim.RMSprop(model.parameters())
308316

309-
if USE_CUDA:
317+
if use_cuda:
310318
model.cuda()
311319

312-
313-
class Variable(autograd.Variable):
314-
315-
def __init__(self, data, *args, **kwargs):
316-
if USE_CUDA:
317-
data = data.cuda()
318-
super(Variable, self).__init__(data, *args, **kwargs)
320+
optimizer = optim.RMSprop(model.parameters())
321+
memory = ReplayMemory(10000)
319322

320323

321324
steps_done = 0
@@ -328,9 +331,10 @@ def select_action(state):
328331
math.exp(-1. * steps_done / EPS_DECAY)
329332
steps_done += 1
330333
if sample > eps_threshold:
331-
return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
334+
return model(
335+
Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
332336
else:
333-
return torch.LongTensor([[random.randrange(2)]])
337+
return LongTensor([[random.randrange(2)]])
334338

335339

336340
episode_durations = []
@@ -339,7 +343,7 @@ def select_action(state):
339343
def plot_durations():
340344
plt.figure(2)
341345
plt.clf()
342-
durations_t = torch.Tensor(episode_durations)
346+
durations_t = torch.FloatTensor(episode_durations)
343347
plt.title('Training...')
344348
plt.xlabel('Episode')
345349
plt.ylabel('Duration')
@@ -349,6 +353,8 @@ def plot_durations():
349353
means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
350354
means = torch.cat((torch.zeros(99), means))
351355
plt.plot(means.numpy())
356+
357+
plt.pause(0.001) # pause a bit so that plots are updated
352358
if is_ipython:
353359
display.clear_output(wait=True)
354360
display.display(plt.gcf())
@@ -370,6 +376,7 @@ def plot_durations():
370376

371377
last_sync = 0
372378

379+
373380
def optimize_model():
374381
global last_sync
375382
if len(memory) < BATCH_SIZE:
@@ -380,10 +387,9 @@ def optimize_model():
380387
batch = Transition(*zip(*transitions))
381388

382389
# Compute a mask of non-final states and concatenate the batch elements
383-
non_final_mask = torch.ByteTensor(
384-
tuple(map(lambda s: s is not None, batch.next_state)))
385-
if USE_CUDA:
386-
non_final_mask = non_final_mask.cuda()
390+
non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
391+
batch.next_state)))
392+
387393
# We don't want to backprop through the expected action values and volatile
388394
# will save us on temporarily changing the model parameters'
389395
# requires_grad to False!
@@ -399,7 +405,7 @@ def optimize_model():
399405
state_action_values = model(state_batch).gather(1, action_batch)
400406

401407
# Compute V(s_{t+1}) for all next states.
402-
next_state_values = Variable(torch.zeros(BATCH_SIZE))
408+
next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
403409
next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
404410
# Now, we don't want to mess up the loss with a volatile flag, so let's
405411
# clear it. After this, we'll just end up with a Variable that has
@@ -440,7 +446,7 @@ def optimize_model():
440446
# Select and perform an action
441447
action = select_action(state)
442448
_, reward, done, _ = env.step(action[0, 0])
443-
reward = torch.Tensor([reward])
449+
reward = Tensor([reward])
444450

445451
# Observe new state
446452
last_screen = current_screen
@@ -463,6 +469,8 @@ def optimize_model():
463469
plot_durations()
464470
break
465471

472+
print('Complete')
473+
env.render(close=True)
466474
env.close()
467475
plt.ioff()
468476
plt.show()

0 commit comments

Comments
 (0)