Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_

This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find an
official leaderboard with various algorithms and visualizations at the
`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: cartpole

   cartpole

As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state and returns a
reward that indicates the consequences of the action. In this task, the
reward is +1 for every timestep, and the episode terminates if the pole
falls over too far or the cart moves more than 2.4 units away from the
center. Better-performing episodes therefore run longer and accumulate
a larger return.

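The reward and termination rules above can be sketched in plain Python. This is
a toy illustration rather than the Gym implementation: the helper name
``episode_return`` and the 12-degree pole-angle limit are assumptions (the text
only says the pole must not fall "too far"), while the 2.4-unit cart limit
comes from the description above.

.. code-block:: python

    import math

    X_LIMIT = 2.4                    # cart-position limit from the text
    THETA_LIMIT = math.radians(12)   # assumed pole-angle limit, in radians


    def episode_return(next_states):
        """Sum +1 per timestep until a (position, angle) pair breaks a limit."""
        total = 0.0
        for x, theta in next_states:
            # +1 for every timestep; this sketch also counts the step that
            # ends the episode.
            total += 1.0
            if abs(x) > X_LIMIT or abs(theta) > THETA_LIMIT:
                break  # the episode terminates here
        return total


    # Hypothetical trajectory: the cart leaves the track on the third step,
    # so the fourth entry is never reached.
    short_run = episode_return([(0.0, 0.0), (1.0, 0.01), (2.5, 0.02), (0.0, 0.0)])

An episode that stays within both limits for more steps accumulates a larger
return, which is the quantity the agent is trained to increase.
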
The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
We take these 4 inputs without any scaling and pass them through a
small fully-connected network with 2 outputs, one for each action.
The network is trained to predict the expected value for each action,
given the input state. The action with the highest expected value is
then chosen.

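The network described above can be sketched as follows. This is a minimal
sketch, not necessarily the architecture used later in this tutorial: the class
name ``TinyQNetwork``, the helper ``greedy_action``, the single hidden layer,
and its width of 128 units are assumptions. Only the 4 inputs and the 2
outputs come from the text.

.. code-block:: python

    import torch
    import torch.nn as nn


    class TinyQNetwork(nn.Module):
        """Small fully-connected net: 4 state values in, 2 action values out."""

        def __init__(self, n_observations=4, n_actions=2, hidden=128):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(n_observations, hidden),  # hidden width is an assumption
                nn.ReLU(),
                nn.Linear(hidden, n_actions),
            )

        def forward(self, state):
            # state has shape (batch, 4); the result has shape (batch, 2)
            return self.layers(state)


    def greedy_action(net, state):
        """Return the index of the action with the highest predicted value."""
        with torch.no_grad():
            return int(net(state).argmax(dim=1).item())


    # Hypothetical usage with a single all-zero state of shape (1, 4).
    net = TinyQNetwork()
    state = torch.zeros(1, 4)
    action = greedy_action(net, state)  # 0 or 1

The raw state is passed in without scaling, as described above, and the action
is simply the index of the larger of the two outputs.
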
**Packages**

First, let's import the packages we need. We need
`gym <https://github.com/openai/gym>`__ for the environment, which can be
installed with ``pip``. If you are running this in Google Colab, run:

.. code-block:: bash

   %%bash
   pip3 install gym[classic_control]

We'll also use the following from PyTorch:

- neural networks (``torch.nn``)
- optimization (``torch.optim``)
- automatic differentiation (``torch.autograd``)

"""

import gym
@@ -445,4 +462,4 @@ def optimize_model():
# Optimization picks a random batch from the replay memory to train the new
# policy. The "older" target_net is also used in optimization to compute the
# expected Q values. A soft update of its weights is performed at every step.
#
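# As a minimal sketch of the soft update mentioned above (not necessarily the
# exact code used in this tutorial), each target weight can be blended toward
# the matching policy weight:
#
#     target <- tau * policy + (1 - tau) * target
#
# The helper name ``soft_update``, the plain-dict weights, and the value of
# ``tau`` are assumptions made for illustration only.

def soft_update(target_weights, policy_weights, tau):
    """Return new target weights moved a fraction ``tau`` toward the policy weights."""
    return {
        name: tau * policy_weights[name] + (1.0 - tau) * target_weights[name]
        for name in target_weights
    }


# Hypothetical usage: with tau = 0.5 the target moves halfway to the policy.
blended = soft_update({"w": 0.0}, {"w": 1.0}, tau=0.5)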