 Reinforcement Learning (DQN) Tutorial
 =====================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
-
-
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
 on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
-
 **Task**
-
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
 `Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
-
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
-
    cartpole
-
 As the agent observes the current state of the environment and chooses
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
 task, rewards are +1 for every incremental timestep and the environment
 terminates if the pole falls over too far or the cart moves more than 2.4
 units away from center. This means that better-performing scenarios run
 for a longer duration, accumulating a larger return.
-
 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
 We take these 4 inputs without any scaling and pass them through a
 small fully-connected network with 2 outputs, one for each action.
 The network is trained to predict the expected value for each action,
 given the input state. The action with the highest expected value is
 then chosen.
-
-
 **Packages**
-
-
 First, let's import the needed packages. We need
 `gym <https://github.com/openai/gym>`__ for the environment,
 installed using `pip`. If you are running this in Google Colab, run:
-
 .. code-block:: bash
-
    %%bash
    pip3 install gym[classic_control]
-
 We'll also use the following from PyTorch:
-
 - neural networks (``torch.nn``)
 - optimization (``torch.optim``)
 - automatic differentiation (``torch.autograd``)
-
 """

 import gym
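
As a quick illustration of the setup described in the docstring above, here is a minimal sketch (not part of this diff) of the small fully-connected network and one environment step. The class name ``DQN``, the hidden-layer width, and the use of the gym >= 0.26 API (``reset()`` returning ``(obs, info)``, ``step()`` returning five values) are assumptions for illustration, not code taken from the tutorial:

.. code-block:: python

   import gym
   import torch
   import torch.nn as nn

   class DQN(nn.Module):
       """4 state values in, one expected value per action out."""
       def __init__(self, n_observations=4, n_actions=2, hidden=128):
           super().__init__()
           self.net = nn.Sequential(
               nn.Linear(n_observations, hidden),
               nn.ReLU(),
               nn.Linear(hidden, n_actions),
           )

       def forward(self, x):
           return self.net(x)

   env = gym.make("CartPole-v1")
   state, _ = env.reset()  # 4 values: cart position/velocity, pole angle/angular velocity
   policy = DQN()
   q_values = policy(torch.as_tensor(state).unsqueeze(0))  # shape (1, 2)
   action = q_values.argmax(dim=1).item()  # the action with the highest expected value
   _, reward, terminated, truncated, _ = env.step(action)  # reward is +1 per surviving step
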
@@ -443,14 +426,12 @@ def optimize_model():
             break
 
 print('Complete')
-env.render()
-env.close()
 durations_t = torch.tensor(episode_durations, dtype=torch.float)
 plt.title('Result')
 plt.xlabel('Episode')
 plt.ylabel('Duration')
-plt.ioff()
 plt.plot(durations_t.numpy())
+plt.ioff()
 plt.show()
 
 ######################################################################
@@ -464,4 +445,4 @@ def optimize_model():
 # Optimization picks a random batch from the replay memory to do training of the
 # new policy. The "older" target_net is also used in optimization to compute the
 # expected Q values. A soft update of its weights is performed at every step.
-#
+#
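
The soft update mentioned in the closing comments blends the policy network's weights into the target network after every optimization step: θ′ ← τ·θ + (1 − τ)·θ′. Below is a minimal sketch of that update; the stand-in networks and the blend rate ``TAU = 0.005`` are illustrative assumptions, not values read from this diff:

.. code-block:: python

   import torch.nn as nn

   policy_net = nn.Linear(4, 2)  # stand-ins for the tutorial's two DQN networks
   target_net = nn.Linear(4, 2)
   target_net.load_state_dict(policy_net.state_dict())  # start out identical
   TAU = 0.005  # assumed blend rate

   # Soft update: move each target weight a fraction TAU toward the policy weight.
   target_state = target_net.state_dict()
   policy_state = policy_net.state_dict()
   for key in policy_state:
       target_state[key] = policy_state[key] * TAU + target_state[key] * (1 - TAU)
   target_net.load_state_dict(target_state)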