File tree: 1 file changed (+10 -5 lines changed)
Columns: original file line number | diff line number | diff line change
72
72
import torch .nn .functional as F
73
73
import torchvision .transforms as T
74
74
75
- env = gym .make ('CartPole-v0' )
75
+ env = gym .make ('CartPole-v0' ). unwrapped
76
76
77
77
is_ipython = 'inline' in matplotlib .get_backend ()
78
78
if is_ipython :
79
79
from IPython import display
80
80
81
-
81
+ plt . ion ()
82
82
######################################################################
83
83
# Replay Memory
84
84
# -------------
@@ -263,8 +263,10 @@ def get_screen():
263
263
return resize (screen ).unsqueeze (0 )
264
264
265
265
env .reset ()
266
+ plt .figure ()
266
267
plt .imshow (get_screen ().squeeze (0 ).permute (
267
268
1 , 2 , 0 ).numpy (), interpolation = 'none' )
269
+ plt .title ('Example extracted screen' )
268
270
plt .show ()
269
271
270
272
@@ -335,9 +337,10 @@ def select_action(state):
335
337
336
338
337
339
def plot_durations ():
338
- plt .figure (1 )
340
+ plt .figure (2 )
339
341
plt .clf ()
340
342
durations_t = torch .Tensor (episode_durations )
343
+ plt .title ('Training...' )
341
344
plt .xlabel ('Episode' )
342
345
plt .ylabel ('Duration' )
343
346
plt .plot (durations_t .numpy ())
@@ -367,7 +370,6 @@ def plot_durations():
367
370
368
371
last_sync = 0
369
372
370
-
371
373
def optimize_model ():
372
374
global last_sync
373
375
if len (memory ) < BATCH_SIZE :
@@ -456,8 +458,11 @@ def optimize_model():
456
458
457
459
# Perform one step of the optimization (on the target network)
458
460
optimize_model ()
459
-
460
461
if done :
461
462
episode_durations .append (t + 1 )
462
463
plot_durations ()
463
464
break
465
+
466
+ env .close ()
467
+ plt .ioff ()
468
+ plt .show ()
You can’t perform that action at this time.
0 commit comments