@@ -68,17 +68,28 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.autograd as autograd
 import torch.nn.functional as F
+from torch.autograd import Variable
 import torchvision.transforms as T
 
+
 env = gym.make('CartPole-v0').unwrapped
 
+# set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
 if is_ipython:
     from IPython import display
 
 plt.ion()
+
+# if gpu is to be used
+use_cuda = torch.cuda.is_available()
+FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
+LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
+ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
+Tensor = FloatTensor
+
+
 ######################################################################
 # Replay Memory
 # -------------
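The type aliases introduced above keep the rest of the script device-agnostic: constructing a tensor through the alias yields a CUDA tensor when a GPU is available and a CPU tensor otherwise. A minimal sketch of the pattern, assuming the pre-0.4 torch API this diff targets:

    import torch

    use_cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

    x = FloatTensor([[1.0, 2.0]])  # lives on the GPU iff use_cuda is True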
@@ -260,12 +271,12 @@ def get_screen():
     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
     screen = torch.from_numpy(screen)
     # Resize, and add a batch dimension (BCHW)
-    return resize(screen).unsqueeze(0)
+    return resize(screen).unsqueeze(0).type(Tensor)
 
 env.reset()
 plt.figure()
-plt.imshow(get_screen().squeeze(0).permute(
-    1, 2, 0).numpy(), interpolation='none')
+plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
+           interpolation='none')
 plt.title('Example extracted screen')
 plt.show()
 
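Since get_screen() now returns a Tensor that may live on the GPU, anything handed to matplotlib must first be copied back to host memory; calling .numpy() directly on a CUDA tensor raises an error, which is why the plotting call gains a .cpu(). A small sketch of the round trip (the 1x3x40x80 shape is illustrative, not taken from the diff):

    screen = torch.rand(1, 3, 40, 80).type(Tensor)           # BCHW, possibly CUDA
    img = screen.cpu().squeeze(0).permute(1, 2, 0).numpy()   # HWC array on the host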
@@ -300,22 +311,14 @@ def get_screen():
 EPS_START = 0.9
 EPS_END = 0.05
 EPS_DECAY = 200
-USE_CUDA = torch.cuda.is_available()
 
 model = DQN()
-memory = ReplayMemory(10000)
-optimizer = optim.RMSprop(model.parameters())
 
-if USE_CUDA:
+if use_cuda:
     model.cuda()
 
-
-class Variable(autograd.Variable):
-
-    def __init__(self, data, *args, **kwargs):
-        if USE_CUDA:
-            data = data.cuda()
-        super(Variable, self).__init__(data, *args, **kwargs)
+optimizer = optim.RMSprop(model.parameters())
+memory = ReplayMemory(10000)
 
 
 steps_done = 0
@@ -328,9 +331,10 @@ def select_action(state):
         math.exp(-1. * steps_done / EPS_DECAY)
     steps_done += 1
     if sample > eps_threshold:
-        return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
+        return model(
+            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
     else:
-        return torch.LongTensor([[random.randrange(2)]])
+        return LongTensor([[random.randrange(2)]])
 
 
 episode_durations = []
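select_action anneals exploration with an exponential schedule: the greedy threshold starts at EPS_START and decays toward EPS_END with time constant EPS_DECAY steps. A quick check of the schedule using the constants above (values rounded):

    import math

    EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200
    for steps_done in (0, 200, 1000):
        eps = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * steps_done / EPS_DECAY)
        print(steps_done, round(eps, 3))  # 0 -> 0.9, 200 -> 0.363, 1000 -> 0.056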
@@ -339,7 +343,7 @@ def select_action(state):
 def plot_durations():
     plt.figure(2)
     plt.clf()
-    durations_t = torch.Tensor(episode_durations)
+    durations_t = torch.FloatTensor(episode_durations)
     plt.title('Training...')
     plt.xlabel('Episode')
     plt.ylabel('Duration')
@@ -349,6 +353,8 @@ def plot_durations():
         means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
         means = torch.cat((torch.zeros(99), means))
         plt.plot(means.numpy())
+
+    plt.pause(0.001)  # pause a bit so that plots are updated
     if is_ipython:
         display.clear_output(wait=True)
         display.display(plt.gcf())
@@ -370,6 +376,7 @@ def plot_durations():
 
 last_sync = 0
 
+
 def optimize_model():
     global last_sync
     if len(memory) < BATCH_SIZE:
@@ -380,10 +387,9 @@ def optimize_model():
     batch = Transition(*zip(*transitions))
 
     # Compute a mask of non-final states and concatenate the batch elements
-    non_final_mask = torch.ByteTensor(
-        tuple(map(lambda s: s is not None, batch.next_state)))
-    if USE_CUDA:
-        non_final_mask = non_final_mask.cuda()
+    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
+                                          batch.next_state)))
+
     # We don't want to backprop through the expected action values and volatile
     # will save us on temporarily changing the model parameters'
     # requires_grad to False!
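The ByteTensor mask built here is later used for masked assignment: only the positions whose episode has not ended receive a value, and the rest stay at zero. In miniature, with hypothetical values and the pre-0.4 indexing semantics this diff assumes:

    mask = ByteTensor([1, 0, 1])             # 1 where a next state exists
    values = torch.zeros(3).type(Tensor)
    values[mask] = FloatTensor([5.0, 7.0])   # fills slots 0 and 2; slot 1 stays 0.0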
@@ -399,7 +405,7 @@ def optimize_model():
     state_action_values = model(state_batch).gather(1, action_batch)
 
     # Compute V(s_{t+1}) for all next states.
-    next_state_values = Variable(torch.zeros(BATCH_SIZE))
+    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
     next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
     # Now, we don't want to mess up the loss with a volatile flag, so let's
     # clear it. After this, we'll just end up with a Variable that has
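The volatile flag mentioned in these comments is infectious in the pre-0.4 autograd: any Variable computed from a volatile input is itself volatile and records no graph, which is cheap for inference but must be cleared before the result feeds a loss. A tiny illustration under that old API (volatile was removed in PyTorch 0.4; relies on the Variable import added in this diff):

    v = Variable(torch.zeros(1), volatile=True)
    w = v + 1
    print(w.volatile)  # True: no graph was recorded, so backward() is unavailable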
@@ -440,7 +446,7 @@ def optimize_model():
         # Select and perform an action
         action = select_action(state)
         _, reward, done, _ = env.step(action[0, 0])
-        reward = torch.Tensor([reward])
+        reward = Tensor([reward])
 
         # Observe new state
         last_screen = current_screen
@@ -463,6 +469,8 @@ def optimize_model():
             plot_durations()
             break
 
+print('Complete')
+env.render(close=True)
 env.close()
 plt.ioff()
 plt.show()