Commit 9770577

better to_cuda ops
1 parent 7b148a6 commit 9770577

File tree

1 file changed: 28 additions, 23 deletions

intermediate_source/reinforcement_q_learning.py

Lines changed: 28 additions & 23 deletions
@@ -68,17 +68,28 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.autograd as autograd
 import torch.nn.functional as F
+from torch.autograd import Variable
 import torchvision.transforms as T
 
+
 env = gym.make('CartPole-v0').unwrapped
 
+# set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
 if is_ipython:
     from IPython import display
 
 plt.ion()
+
+# if gpu is to be used
+use_cuda = torch.cuda.is_available()
+FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
+LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
+ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
+Tensor = FloatTensor
+
+
 ######################################################################
 # Replay Memory
 # -------------
@@ -258,14 +269,14 @@ def get_screen():
     # Convert to float, rescale, convert to torch tensor
     # (this doesn't require a copy)
     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
-    screen = torch.from_numpy(screen)
+    screen = torch.from_numpy(screen).type(Tensor)
     # Resize, and add a batch dimension (BCHW)
     return resize(screen).unsqueeze(0)
 
 env.reset()
 plt.figure()
-plt.imshow(get_screen().squeeze(0).permute(
-    1, 2, 0).numpy(), interpolation='none')
+plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
+           interpolation='none')
 plt.title('Example extracted screen')
 plt.show()
 
@@ -300,22 +311,14 @@ def get_screen():
 EPS_START = 0.9
 EPS_END = 0.05
 EPS_DECAY = 200
-USE_CUDA = torch.cuda.is_available()
 
 model = DQN()
-memory = ReplayMemory(10000)
-optimizer = optim.RMSprop(model.parameters())
 
-if USE_CUDA:
+if use_cuda:
     model.cuda()
 
-
-class Variable(autograd.Variable):
-
-    def __init__(self, data, *args, **kwargs):
-        if USE_CUDA:
-            data = data.cuda()
-        super(Variable, self).__init__(data, *args, **kwargs)
+optimizer = optim.RMSprop(model.parameters())
+memory = ReplayMemory(10000)
 
 
 steps_done = 0
@@ -328,9 +331,10 @@ def select_action(state):
         math.exp(-1. * steps_done / EPS_DECAY)
     steps_done += 1
     if sample > eps_threshold:
-        return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
+        return model(
+            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
     else:
-        return torch.LongTensor([[random.randrange(2)]])
+        return LongTensor([[random.randrange(2)]])
 
 
 episode_durations = []
@@ -339,7 +343,7 @@ def select_action(state):
 def plot_durations():
     plt.figure(2)
     plt.clf()
-    durations_t = torch.Tensor(episode_durations)
+    durations_t = torch.FloatTensor(episode_durations)
     plt.title('Training...')
     plt.xlabel('Episode')
     plt.ylabel('Duration')
@@ -370,6 +374,7 @@ def plot_durations():
 
 last_sync = 0
 
+
 def optimize_model():
     global last_sync
     if len(memory) < BATCH_SIZE:
@@ -380,10 +385,9 @@ def optimize_model():
     batch = Transition(*zip(*transitions))
 
     # Compute a mask of non-final states and concatenate the batch elements
-    non_final_mask = torch.ByteTensor(
-        tuple(map(lambda s: s is not None, batch.next_state)))
-    if USE_CUDA:
-        non_final_mask = non_final_mask.cuda()
+    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
+                                          batch.next_state)))
+
     # We don't want to backprop through the expected action values and volatile
     # will save us on temporarily changing the model parameters'
     # requires_grad to False!
@@ -440,7 +444,7 @@ def optimize_model():
         # Select and perform an action
         action = select_action(state)
         _, reward, done, _ = env.step(action[0, 0])
-        reward = torch.Tensor([reward])
+        reward = Tensor([reward])
 
         # Observe new state
         last_screen = current_screen
@@ -463,6 +467,7 @@ def optimize_model():
             plot_durations()
             break
 
+print('Complete')
 env.close()
 plt.ioff()
 plt.show()
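
For reference, the device-dispatch pattern this commit adopts can be sketched standalone. This is a minimal sketch assuming the legacy (pre-0.4) PyTorch API that the tutorial targets, where per-device tensor type classes such as torch.cuda.FloatTensor exist; the reward and action values below are hypothetical, for illustration only:

import torch

# Bind tensor constructors once, based on the available hardware.
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
Tensor = FloatTensor

# Call sites stay device-agnostic: tensors land on the GPU when one is
# present, with no scattered .cuda() calls or USE_CUDA branches.
reward = Tensor([1.0])      # hypothetical reward, float tensor
action = LongTensor([[0]])  # hypothetical action index, int64 tensor

Compared with the removed Variable subclass, which moved data onto the GPU inside __init__, these module-level aliases make the device decision at tensor construction time, which is why the explicit if USE_CUDA: branches disappear from optimize_model() and the training loop.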
