diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index d52c4f5fef1..cb9abc229c9 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -6,14 +6,14 @@
 
 
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v0 task from the `OpenAI Gym `__.
+on the CartPole-v0 task from the `OpenAI Gym `__.
 
 **Task**
 
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
-`Gym website `__.
+`Gym website `__.
 
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
@@ -74,7 +74,7 @@
 import torchvision.transforms as T
 
 
-env = gym.make('CartPole-v0').unwrapped
+env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
 
 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
@@ -254,7 +254,7 @@ def get_cart_location(screen_width):
 def get_screen():
     # Returned screen requested by gym is 400x600x3, but is sometimes larger
     # such as 800x1200x3. Transpose it into torch order (CHW).
-    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+    screen = env.render().transpose((2, 0, 1))
     # Cart is in the lower half, so strip off the top and bottom of the screen
     _, screen_height, screen_width = screen.shape
     screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
@@ -461,7 +461,7 @@ def optimize_model():
     for t in count():
         # Select and perform an action
        action = select_action(state)
-        _, reward, done, _ = env.step(action.item())
+        _, reward, done, _, _ = env.step(action.item())
         reward = torch.tensor([reward], device=device)
 
         # Observe new state
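
Note: below is a minimal sketch of the Gym API this diff migrates to, assuming gym 0.25+ is installed; the variable names and the `done = terminated or truncated` line are illustrative and not part of the tutorial itself.

    # Minimal sketch: with new_step_api=True, env.step() returns five values,
    # splitting the old `done` flag into `terminated` and `truncated`, and
    # env.render() takes no mode argument once render_mode is set in gym.make().
    import gym

    env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array')
    obs = env.reset()
    frame = env.render()  # HxWx3 RGB array, since render_mode='single_rgb_array'

    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated  # rough equivalent of the old single `done` flag
    env.close()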