From c8e5d84fbd8e3da4e82789bf4092b74a5dfec567 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Wed, 11 Jan 2023 20:00:23 +0000 Subject: [PATCH] initial commit --- .../reinforcement_q_learning.py | 43 ++++++++----------- requirements.txt | 1 + 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 61e374a8b21..dd4a84d5908 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -3,17 +3,18 @@ Reinforcement Learning (DQN) Tutorial ===================================== **Author**: `Adam Paszke `_ + `Mark Towers `_ This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent -on the CartPole-v1 task from the `OpenAI Gym `__. +on the CartPole-v1 task from `Gymnasium `__. **Task** The agent has to decide between two actions - moving the cart left or -right - so that the pole attached to it stays upright. You can find an -official leaderboard with various algorithms and visualizations at the -`Gym website `__. +right - so that the pole attached to it stays upright. You can find more +information about the environment and other more challenging environments at +`Gymnasium's website `__. .. figure:: /_static/img/cartpole.gif :alt: cartpole @@ -24,7 +25,7 @@ an action, the environment *transitions* to a new state, and also returns a reward that indicates the consequences of the action. In this task, rewards are +1 for every incremental timestep and the environment -terminates if the pole falls over too far or the cart moves more then 2.4 +terminates if the pole falls over too far or the cart moves more than 2.4 units away from center. This means better performing scenarios will run for longer duration, accumulating larger return. @@ -41,13 +42,15 @@ First, let's import needed packages. Firstly, we need -`gym `__ for the environment -Install by using `pip`. If you are running this in Google colab, run: +`gymnasium `__ for the environment, +installed by using `pip`. This is a fork of the original OpenAI +Gym project and maintained by the same team since Gym v0.19. +If you are running this in Google colab, run: .. code-block:: bash %%bash - pip3 install gym[classic_control] + pip3 install gymnasium[classic_control] We'll also use the following from PyTorch: @@ -57,10 +60,9 @@ """ -import gym +import gymnasium as gym import math import random -import numpy as np import matplotlib import matplotlib.pyplot as plt from collections import namedtuple, deque @@ -71,12 +73,7 @@ import torch.optim as optim import torch.nn.functional as F -if gym.__version__[:4] == '0.26': - env = gym.make('CartPole-v1') -elif gym.__version__[:4] == '0.25': - env = gym.make('CartPole-v1', new_step_api=True) -else: - raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}") +env = gym.make("CartPole-v1") # set up matplotlib is_ipython = 'inline' in matplotlib.get_backend() @@ -117,7 +114,7 @@ class ReplayMemory(object): def __init__(self, capacity): - self.memory = deque([],maxlen=capacity) + self.memory = deque([], maxlen=capacity) def push(self, *args): """Save a transition""" @@ -261,10 +258,7 @@ def forward(self, x): # Get number of actions from gym action space n_actions = env.action_space.n # Get the number of state observations -if gym.__version__[:4] == '0.26': - state, _ = env.reset() -elif gym.__version__[:4] == '0.25': - state, _ = env.reset(return_info=True) +state, info = env.reset() n_observations = len(state) policy_net = DQN(n_observations, n_actions).to(device) @@ -286,7 +280,7 @@ def select_action(state): steps_done += 1 if sample > eps_threshold: with torch.no_grad(): - # t.max(1) will return largest column value of each row. + # t.max(1) will return the largest column value of each row. # second column on max result is index of where max element was # found, so we pick action with the larger expected reward. return policy_net(state).max(1)[1].view(1, 1) @@ -410,10 +404,7 @@ def optimize_model(): for i_episode in range(num_episodes): # Initialize the environment and get it's state - if gym.__version__[:4] == '0.26': - state, _ = env.reset() - elif gym.__version__[:4] == '0.25': - state, _ = env.reset(return_info=True) + state, info = env.reset() state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) for t in count(): action = select_action(state) diff --git a/requirements.txt b/requirements.txt index 367a9d576c3..53a62221fea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,6 +51,7 @@ pillow==9.3.0 wget gym==0.25.1 gym-super-mario-bros==7.4.0 +gymnasium==0.27.0 timm iopath pygame==2.1.2