Skip to content

Commit f71137c

Browse files
authored
Merge pull request #67 from chsasank/patch-rl
Patch RL tutorial
2 parents e5d1b46 + cefcbbb commit f71137c

File tree

1 file changed

+10 -5 lines changed

1 file changed

+10 -5 lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,13 +72,13 @@
7272
import torch.nn.functional as F
7373
import torchvision.transforms as T
7474

75-
env = gym.make('CartPole-v0')
75+
env = gym.make('CartPole-v0').unwrapped
7676

7777
is_ipython = 'inline' in matplotlib.get_backend()
7878
if is_ipython:
7979
from IPython import display
8080

81-
81+
plt.ion()
8282
######################################################################
8383
# Replay Memory
8484
# -------------
@@ -263,8 +263,10 @@ def get_screen():
263263
return resize(screen).unsqueeze(0)
264264

265265
env.reset()
266+
plt.figure()
266267
plt.imshow(get_screen().squeeze(0).permute(
267268
1, 2, 0).numpy(), interpolation='none')
269+
plt.title('Example extracted screen')
268270
plt.show()
269271

270272

@@ -335,9 +337,10 @@ def select_action(state):
335337

336338

337339
def plot_durations():
338-
plt.figure(1)
340+
plt.figure(2)
339341
plt.clf()
340342
durations_t = torch.Tensor(episode_durations)
343+
plt.title('Training...')
341344
plt.xlabel('Episode')
342345
plt.ylabel('Duration')
343346
plt.plot(durations_t.numpy())
@@ -367,7 +370,6 @@ def plot_durations():
367370

368371
last_sync = 0
369372

370-
371373
def optimize_model():
372374
global last_sync
373375
if len(memory) < BATCH_SIZE:
@@ -456,8 +458,11 @@ def optimize_model():
456458

457459
# Perform one step of the optimization (on the target network)
458460
optimize_model()
459-
460461
if done:
461462
episode_durations.append(t + 1)
462463
plot_durations()
463464
break
465+
466+
env.close()
467+
plt.ioff()
468+
plt.show()

0 commit comments

Comments (0)