@@ -68,17 +68,28 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.autograd as autograd
 import torch.nn.functional as F
+from torch.autograd import Variable
 import torchvision.transforms as T
 
+
 env = gym.make('CartPole-v0').unwrapped
 
+# set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
 if is_ipython:
     from IPython import display
 
 plt.ion()
+
+# if gpu is to be used
+use_cuda = torch.cuda.is_available()
+FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
+LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
+ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
+Tensor = FloatTensor
+
+
 ######################################################################
 # Replay Memory
 # -------------
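A note on the typed-tensor aliases added in the hunk above: binding FloatTensor, LongTensor, and ByteTensor once at import time lets every later tensor construction be device-agnostic, instead of guarding each call site with a CUDA check. A minimal standalone sketch of the pattern (pre-0.4 PyTorch API, matching the tutorial; the shape is arbitrary):

import torch

use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# The same line allocates on the GPU when one is available, on the CPU otherwise:
batch = FloatTensor(32, 4).zero_()
print(batch.is_cuda)  # True only when a GPU was found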
@@ -258,14 +269,14 @@ def get_screen():
     # Convert to float, rescale, convert to torch tensor
     # (this doesn't require a copy)
     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
-    screen = torch.from_numpy(screen)
+    screen = torch.from_numpy(screen).type(Tensor)
     # Resize, and add a batch dimension (BCHW)
     return resize(screen).unsqueeze(0)
 
 env.reset()
 plt.figure()
-plt.imshow(get_screen().squeeze(0).permute(
-    1, 2, 0).numpy(), interpolation='none')
+plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
+           interpolation='none')
 plt.title('Example extracted screen')
 plt.show()
 
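The .cpu() inserted before .numpy() above matters once get_screen() can return a CUDA tensor: NumPy arrays can only view host memory, so calling .numpy() on a GPU tensor raises an error. The safe idiom, sketched standalone (t stands in for any tensor of unknown device):

import torch

t = torch.rand(3, 40, 80)
if torch.cuda.is_available():
    t = t.cuda()
# t.numpy() would fail if t lives on the GPU; hop through host memory first.
# .cpu() returns the tensor unchanged when it already lives on the CPU.
arr = t.cpu().numpy()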
@@ -300,22 +311,14 @@ def get_screen():
 EPS_START = 0.9
 EPS_END = 0.05
 EPS_DECAY = 200
-USE_CUDA = torch.cuda.is_available()
 
 model = DQN()
-memory = ReplayMemory(10000)
-optimizer = optim.RMSprop(model.parameters())
 
-if USE_CUDA:
+if use_cuda:
     model.cuda()
 
-
-class Variable(autograd.Variable):
-
-    def __init__(self, data, *args, **kwargs):
-        if USE_CUDA:
-            data = data.cuda()
-        super(Variable, self).__init__(data, *args, **kwargs)
+optimizer = optim.RMSprop(model.parameters())
+memory = ReplayMemory(10000)
 
 
 steps_done = 0
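The custom Variable subclass removed above moved its data to the GPU implicitly inside the constructor; with the module-level tensor-type aliases, that conversion is written out at each call site instead, which keeps device placement visible. A minimal sketch of the replacement pattern (pre-0.4 API as used throughout this tutorial; the state shape is an arbitrary stand-in for a screen batch):

import torch
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

state = torch.rand(1, 3, 40, 80)  # a fake screen batch
# .type(FloatTensor) does the cpu-to-gpu move (when a GPU is present)
# that the old Variable subclass hid in its constructor:
v = Variable(state, volatile=True).type(FloatTensor)
print(v.data.is_cuda)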
@@ -328,9 +331,10 @@ def select_action(state):
         math.exp(-1. * steps_done / EPS_DECAY)
     steps_done += 1
     if sample > eps_threshold:
-        return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
+        return model(
+            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
     else:
-        return torch.LongTensor([[random.randrange(2)]])
+        return LongTensor([[random.randrange(2)]])
 
 
 episode_durations = []
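In select_action above, the math.exp(...) context line completes the tutorial's annealing schedule, eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY), so the exploration rate decays exponentially from EPS_START toward the EPS_END floor with time constant EPS_DECAY. A quick standalone check of the constants used in the diff:

import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200

def eps_threshold(steps_done):
    # exponential decay from EPS_START toward the EPS_END floor
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

print(eps_threshold(0))     # 0.9 at the first step
print(eps_threshold(200))   # ~0.363 after one time constant
print(eps_threshold(2000))  # ~0.05: almost pure exploitation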
@@ -339,7 +343,7 @@ def select_action(state):
 def plot_durations():
     plt.figure(2)
     plt.clf()
-    durations_t = torch.Tensor(episode_durations)
+    durations_t = torch.FloatTensor(episode_durations)
     plt.title('Training...')
     plt.xlabel('Episode')
     plt.ylabel('Duration')
@@ -370,6 +374,7 @@ def plot_durations():
 
 last_sync = 0
 
+
 def optimize_model():
     global last_sync
     if len(memory) < BATCH_SIZE:
@@ -380,10 +385,9 @@ def optimize_model():
     batch = Transition(*zip(*transitions))
 
     # Compute a mask of non-final states and concatenate the batch elements
-    non_final_mask = torch.ByteTensor(
-        tuple(map(lambda s: s is not None, batch.next_state)))
-    if USE_CUDA:
-        non_final_mask = non_final_mask.cuda()
+    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
+                                          batch.next_state)))
+
     # We don't want to backprop through the expected action values and volatile
     # will save us from temporarily changing the model parameters'
     # requires_grad to False!
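The comment above relies on pre-0.4 autograd semantics: marking any input Variable with volatile=True propagates through every operation in the forward pass, so no graph is built and no gradients are tracked for the model's parameters, without having to flip each parameter's requires_grad flag. A minimal illustration (pre-0.4 API, matching the tutorial; PyTorch 0.4+ replaced this mechanism with the torch.no_grad() context manager):

import torch
from torch.autograd import Variable

w = Variable(torch.rand(4, 2), requires_grad=True)  # stands in for model weights
x = Variable(torch.rand(1, 4), volatile=True)       # inference-only input
y = x.mm(w)
print(y.volatile)        # True: volatility infects the whole result
print(y.requires_grad)   # False: no graph, so no backprop through w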
@@ -440,7 +444,7 @@ def optimize_model():
         # Select and perform an action
         action = select_action(state)
         _, reward, done, _ = env.step(action[0, 0])
-        reward = torch.Tensor([reward])
+        reward = Tensor([reward])
 
         # Observe new state
         last_screen = current_screen
@@ -463,6 +467,7 @@ def optimize_model():
             plot_durations()
             break
 
+print('Complete')
 env.close()
 plt.ioff()
 plt.show()