diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 611cfb32448..1522db24bc1 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -297,11 +297,14 @@ def select_action(state): episode_durations = [] -def plot_durations(): +def plot_durations(show_result=False): plt.figure(1) - plt.clf() durations_t = torch.tensor(episode_durations, dtype=torch.float) - plt.title('Training...') + if show_result: + plt.title('Result') + else: + plt.clf() + plt.title('Training...') plt.xlabel('Episode') plt.ylabel('Duration') plt.plot(durations_t.numpy()) @@ -313,8 +316,11 @@ def plot_durations(): plt.pause(0.001) # pause a bit so that plots are updated if is_ipython: - display.clear_output(wait=True) - display.display(plt.gcf()) + if not show_result: + display.display(plt.gcf()) + display.clear_output(wait=True) + else: + display.display(plt.gcf()) ###################################################################### @@ -443,6 +449,7 @@ def optimize_model(): break print('Complete') +plot_durations(show_result=True) plt.ioff() plt.show()