From 9662c7c0c6ba6f7174694352514d47a5e107d049 Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Tue, 6 Dec 2022 14:24:42 +0800
Subject: [PATCH 1/7] Fix a display bug that prevented the figure from updating correctly.

---
 intermediate_source/reinforcement_q_learning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 7353bba05a4..4cfdfb37564 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -378,8 +378,8 @@ def plot_durations():
 
     plt.pause(0.001)  # pause a bit so that plots are updated
     if is_ipython:
-        display.clear_output(wait=True)
         display.display(plt.gcf())
+        display.clear_output(wait=True)
 
 
 ######################################################################

From 69dc83c76826116e523486e34babf114a314c0df Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Thu, 8 Dec 2022 19:54:22 +0800
Subject: [PATCH 2/7] fix display bug

---
 intermediate_source/reinforcement_q_learning.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 4cfdfb37564..5ac4bcecae3 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -378,8 +378,8 @@ def plot_durations():
 
     plt.pause(0.001)  # pause a bit so that plots are updated
     if is_ipython:
-        display.display(plt.gcf())
         display.clear_output(wait=True)
+        display.display(plt.gcf())
 
 
 ######################################################################
@@ -500,7 +500,12 @@ def optimize_model():
 print('Complete')
 env.render()
 env.close()
+durations_t = torch.tensor(episode_durations, dtype=torch.float)
+plt.title('Result')
+plt.xlabel('Episode')
+plt.ylabel('Duration')
 plt.ioff()
+plt.plot(durations_t.numpy())
 plt.show()
 
 ######################################################################

From da66936fe8f512f52184068100749bd5a00c2ea9 Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Thu, 8 Dec 2022 20:37:32 +0800
Subject: [PATCH 3/7] fix display bug

---
 intermediate_source/reinforcement_q_learning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 5ac4bcecae3..c288f6fd34e 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -378,8 +378,8 @@ def plot_durations():
 
     plt.pause(0.001)  # pause a bit so that plots are updated
     if is_ipython:
-        display.clear_output(wait=True)
         display.display(plt.gcf())
+        display.clear_output(wait=True)
 
 
 ######################################################################
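Patches 1 through 3 flip the order of `display.display` and `display.clear_output` back and forth; the ordering the series finally settles on (display the frame first, then schedule a deferred clear) gives flicker-free in-place updates, because `clear_output(wait=True)` postpones the actual clearing until the next output arrives. Below is a minimal sketch of that pattern, assuming a Jupyter/IPython session with matplotlib installed; the plotted range is a synthetic stand-in for the tutorial's duration data:

.. code-block:: python

   import matplotlib.pyplot as plt
   from IPython import display  # assumes an IPython/Jupyter session

   plt.ion()
   for step in range(5):
       plt.figure(1)
       plt.clf()
       plt.plot(range(step + 1))        # stand-in for the durations curve
       plt.pause(0.001)                 # let the backend redraw
       display.display(plt.gcf())       # show the current frame...
       display.clear_output(wait=True)  # ...clear it only when the next one arrives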
From 3b2fd56deadcefca5aec01881b7be183bd0f9191 Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Mon, 12 Dec 2022 17:35:16 +0800
Subject: [PATCH 4/7] Resolve conflict

---
 .../reinforcement_q_learning.py | 23 ++-----------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 37e6764127a..816c807a57b 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -3,23 +3,16 @@
 Reinforcement Learning (DQN) Tutorial
 =====================================
 **Author**: `Adam Paszke `_
-
-
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
 on the CartPole-v1 task from the `OpenAI Gym `__.
-
 **Task**
-
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
 `Gym website `__.
-
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
-
    cartpole
-
 As the agent observes the current state of the environment and chooses
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
@@ -27,7 +20,6 @@
 terminates if the pole falls over too far or the cart moves more than
 2.4 units away from center. This means better performing scenarios will
 run for longer duration, accumulating larger return.
-
 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
 We take these 4 inputs without any scaling and pass them through a
@@ -35,26 +27,17 @@
 The network is trained to predict the expected value for each action,
 given the input state. The action with the highest expected value is
 then chosen.
-
-
 **Packages**
-
-
 First, let's import needed packages. Firstly, we need
 `gym `__ for the environment.
 Install it by using `pip`. If you are running this in Google colab, run:
-
 .. code-block:: bash
-
    %%bash
    pip3 install gym[classic_control]
-
 We'll also use the following from PyTorch:
-
 -  neural networks (``torch.nn``)
 -  optimization (``torch.optim``)
 -  automatic differentiation (``torch.autograd``)
-
 """
 
 import gym
@@ -443,14 +426,12 @@ def optimize_model():
             break
 
 print('Complete')
-env.render()
-env.close()
 durations_t = torch.tensor(episode_durations, dtype=torch.float)
 plt.title('Result')
 plt.xlabel('Episode')
 plt.ylabel('Duration')
-plt.ioff()
 plt.plot(durations_t.numpy())
+plt.ioff()
 plt.show()
 
 ######################################################################
@@ -464,4 +445,4 @@ def optimize_model():
 # Optimization picks a random batch from the replay memory to do training of the
 # new policy. The "older" target_net is also used in optimization to compute the
 # expected Q values. A soft update of its weights is performed at every step.
-#
+#
\ No newline at end of file
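With patch 4 applied, the tail of the script plots the collected durations before leaving interactive mode, so the final figure actually shows the curve when `plt.show()` blocks. A self-contained sketch of that closing block, with synthetic durations standing in for the training loop's output:

.. code-block:: python

   import matplotlib.pyplot as plt
   import torch

   episode_durations = [10, 22, 31, 47, 60]  # stand-in for the training loop's output

   print('Complete')
   durations_t = torch.tensor(episode_durations, dtype=torch.float)
   plt.title('Result')
   plt.xlabel('Episode')
   plt.ylabel('Duration')
   plt.plot(durations_t.numpy())
   plt.ioff()   # leave interactive mode only after the curve is drawn
   plt.show()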
From 684421ce42430474f32b641caab11bd1da23ef80 Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Mon, 12 Dec 2022 17:43:47 +0800
Subject: [PATCH 5/7] Resolve conflict

---
 .../reinforcement_q_learning.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 816c807a57b..9e3b3738406 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -3,16 +3,23 @@
 Reinforcement Learning (DQN) Tutorial
 =====================================
 **Author**: `Adam Paszke `_
+
+
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
 on the CartPole-v1 task from the `OpenAI Gym `__.
+
 **Task**
+
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
 `Gym website `__.
+
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
+
    cartpole
+
 As the agent observes the current state of the environment and chooses
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
@@ -20,6 +27,7 @@
 terminates if the pole falls over too far or the cart moves more than
 2.4 units away from center. This means better performing scenarios will
 run for longer duration, accumulating larger return.
+
 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
 We take these 4 inputs without any scaling and pass them through a
@@ -27,17 +35,26 @@
 The network is trained to predict the expected value for each action,
 given the input state. The action with the highest expected value is
 then chosen.
+
+
 **Packages**
+
+
 First, let's import needed packages. Firstly, we need
 `gym `__ for the environment.
 Install it by using `pip`. If you are running this in Google colab, run:
+
 .. code-block:: bash
+
    %%bash
    pip3 install gym[classic_control]
+
 We'll also use the following from PyTorch:
+
 -  neural networks (``torch.nn``)
 -  optimization (``torch.optim``)
 -  automatic differentiation (``torch.autograd``)
+
 """
 
 import gym
@@ -445,4 +462,4 @@ def optimize_model():
 # Optimization picks a random batch from the replay memory to do training of the
 # new policy. The "older" target_net is also used in optimization to compute the
 # expected Q values. A soft update of its weights is performed at every step.
-#
\ No newline at end of file
+#

From 0c252e6eba1d2201b86c57efdcb50cbb8562478b Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Mon, 12 Dec 2022 18:18:40 +0800
Subject: [PATCH 6/7] Modify plot_durations

---
 .../reinforcement_q_learning.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 9e3b3738406..be7df7cb615 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -297,11 +297,14 @@ def select_action(state):
 episode_durations = []
 
 
-def plot_durations():
+def plot_durations(show_result=False):
     plt.figure(1)
-    plt.clf()
     durations_t = torch.tensor(episode_durations, dtype=torch.float)
-    plt.title('Training...')
+    if show_result:
+        plt.title('Result')
+    else:
+        plt.clf()
+        plt.title('Training...')
     plt.xlabel('Episode')
     plt.ylabel('Duration')
     plt.plot(durations_t.numpy())
@@ -312,9 +315,11 @@
         plt.plot(means.numpy())
 
     plt.pause(0.001)  # pause a bit so that plots are updated
-    if is_ipython:
+    if is_ipython and not show_result:
         display.display(plt.gcf())
         display.clear_output(wait=True)
+    else:
+        display.display(plt.gcf())
 
 
 ######################################################################
@@ -443,11 +448,7 @@ def optimize_model():
             break
 
 print('Complete')
-durations_t = torch.tensor(episode_durations, dtype=torch.float)
-plt.title('Result')
-plt.xlabel('Episode')
-plt.ylabel('Duration')
-plt.plot(durations_t.numpy())
+plot_durations(show_result=True)
 plt.ioff()
 plt.show()
 
 ######################################################################
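Patch 6 folds the result plot into `plot_durations(show_result=True)`, but its dispatch has a latent bug: under plain `python`, `is_ipython` is false and `display` is never imported, yet the `else` branch still calls `display.display`, raising a `NameError`. That appears to be the failure patch 7 addresses by nesting the `show_result` check inside the `is_ipython` check. A minimal repro sketch, reusing the tutorial's backend detection (the `try`/`except` is added here only to make the failure visible without crashing):

.. code-block:: python

   import matplotlib
   import matplotlib.pyplot as plt

   # The tutorial imports `display` only when an inline backend is detected.
   is_ipython = 'inline' in matplotlib.get_backend()
   if is_ipython:
       from IPython import display

   plt.plot([1, 2, 3])
   show_result = True

   # Patch 6's dispatch: under plain `python` the else branch still runs.
   try:
       if is_ipython and not show_result:
           display.display(plt.gcf())
           display.clear_output(wait=True)
       else:
           display.display(plt.gcf())
   except NameError as err:
       print('crashes outside IPython:', err)  # `display` was never imported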
From b3d451866b64ada5ba664f98b1bc745c80240722 Mon Sep 17 00:00:00 2001
From: Yichi-Lionel-Cheung
Date: Mon, 12 Dec 2022 19:00:13 +0800
Subject: [PATCH 7/7] fix a bug when not using ipython

---
 intermediate_source/reinforcement_q_learning.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index be7df7cb615..1522db24bc1 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -315,11 +315,12 @@ def plot_durations(show_result=False):
         plt.plot(means.numpy())
 
     plt.pause(0.001)  # pause a bit so that plots are updated
-    if is_ipython and not show_result:
-        display.display(plt.gcf())
-        display.clear_output(wait=True)
-    else:
-        display.display(plt.gcf())
+    if is_ipython:
+        if not show_result:
+            display.display(plt.gcf())
+            display.clear_output(wait=True)
+        else:
+            display.display(plt.gcf())
 
 
 ######################################################################
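Taken together, the series leaves the plotting helper in the shape below. This consolidated sketch is self-contained so it can be run outside the tutorial: the durations are synthetic, and the moving-average lines (`unfold`/`cat`) are quoted from the tutorial's surrounding code, which these hunks only partially show:

.. code-block:: python

   import matplotlib
   import matplotlib.pyplot as plt
   import torch

   # `display` is imported only under IPython -- exactly why patch 7 keeps
   # every display.* call behind the is_ipython check.
   is_ipython = 'inline' in matplotlib.get_backend()
   if is_ipython:
       from IPython import display

   plt.ion()
   episode_durations = []

   def plot_durations(show_result=False):
       plt.figure(1)
       durations_t = torch.tensor(episode_durations, dtype=torch.float)
       if show_result:
           plt.title('Result')
       else:
           plt.clf()                    # clear only the live training figure
           plt.title('Training...')
       plt.xlabel('Episode')
       plt.ylabel('Duration')
       plt.plot(durations_t.numpy())
       if len(durations_t) >= 100:      # 100-episode moving average
           means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
           means = torch.cat((torch.zeros(99), means))
           plt.plot(means.numpy())
       plt.pause(0.001)  # pause a bit so that plots are updated
       if is_ipython:
           if not show_result:
               display.display(plt.gcf())
               display.clear_output(wait=True)  # replaced by the next frame
           else:
               display.display(plt.gcf())       # final figure stays visible

   # Synthetic driver standing in for the training loop:
   for duration in [12, 25, 40, 60, 95]:
       episode_durations.append(duration)
       plot_durations()

   print('Complete')
   plot_durations(show_result=True)
   plt.ioff()
   plt.show()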