Skip to content

Commit 0c252e6

Browse files
Modify plot_durations
1 parent 684421c commit 0c252e6

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,14 @@ def select_action(state):
297297
episode_durations = []
298298

299299

300-
def plot_durations():
300+
def plot_durations(show_result=False):
301301
plt.figure(1)
302-
plt.clf()
303302
durations_t = torch.tensor(episode_durations, dtype=torch.float)
304-
plt.title('Training...')
303+
if show_result:
304+
plt.title('Result')
305+
else:
306+
plt.clf()
307+
plt.title('Training...')
305308
plt.xlabel('Episode')
306309
plt.ylabel('Duration')
307310
plt.plot(durations_t.numpy())
@@ -312,9 +315,11 @@ def plot_durations():
312315
plt.plot(means.numpy())
313316

314317
plt.pause(0.001) # pause a bit so that plots are updated
315-
if is_ipython:
318+
if is_ipython and not show_result:
316319
display.display(plt.gcf())
317320
display.clear_output(wait=True)
321+
else:
322+
display.display(plt.gcf())
318323

319324

320325
######################################################################
@@ -443,11 +448,7 @@ def optimize_model():
443448
break
444449

445450
print('Complete')
446-
durations_t = torch.tensor(episode_durations, dtype=torch.float)
447-
plt.title('Result')
448-
plt.xlabel('Episode')
449-
plt.ylabel('Duration')
450-
plt.plot(durations_t.numpy())
451+
plot_durations(show_result=True)
451452
plt.ioff()
452453
plt.show()
453454

0 commit comments

Comments
 (0)