
Commit deb0e34

Update RL Mario Tutorial (#2075)
Switch to the new step API when gym v0.25 is used, making the tutorial compatible with gym v0.26. Also, slightly clean up the tensor-creation logic on GPU and CPU.
1 parent 7ab03a0 commit deb0e34
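For context on the API switch: in gym 0.26 (and in gym 0.25 with new_step_api=True), env.step() returns a 5-tuple (obs, reward, terminated, truncated, info) instead of the classic 4-tuple (obs, reward, done, info), which is why the diff below unpacks an extra trunc value. A minimal sketch of version-agnostic unpacking, assuming gym and its CartPole-v1 environment are installed; the helper name step_compat is illustrative and not part of this commit:

import gym


def step_compat(env, action):
    """Normalize env.step() output to (obs, reward, done, info) across gym versions."""
    result = env.step(action)
    if len(result) == 5:
        # New API: obs, reward, terminated, truncated, info
        obs, reward, terminated, truncated, info = result
        done = terminated or truncated
    else:
        # Old API: obs, reward, done, info
        obs, reward, done, info = result
    return obs, reward, done, info


env = gym.make("CartPole-v1")
env.reset()
obs, reward, done, info = step_compat(env, env.action_space.sample())
print(reward, done)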


intermediate_source/mario_rl_tutorial.py

Lines changed: 31 additions & 34 deletions
@@ -31,8 +31,10 @@
 ######################################################################
 #
 #
-
-# !pip install gym-super-mario-bros==7.3.0
+# .. code-block:: bash
+#
+#     %%bash
+#     pip install gym-super-mario-bros==7.4.0

 import torch
 from torch import nn
@@ -95,16 +97,19 @@
 # (next) state, reward and other info.
 #

-# Initialize Super Mario environment
-env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
+# Initialize Super Mario environment (in v0.26 change render mode to 'human' to see results on the screen)
+if gym.__version__ < '0.26':
+    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", new_step_api=True)
+else:
+    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb', apply_api_compatibility=True)

 # Limit the action-space to
 #   0. walk right
 #   1. jump right
 env = JoypadSpace(env, [["right"], ["right", "A"]])

 env.reset()
-next_state, reward, done, info = env.step(action=0)
+next_state, reward, done, trunc, info = env.step(action=0)
 print(f"{next_state.shape},\n {reward},\n {done},\n {info}")


@@ -151,14 +156,13 @@ def __init__(self, env, skip):
     def step(self, action):
         """Repeat action, and sum reward"""
         total_reward = 0.0
-        done = False
         for i in range(self._skip):
             # Accumulate reward and repeat the same action
-            obs, reward, done, info = self.env.step(action)
+            obs, reward, done, trunk, info = self.env.step(action)
             total_reward += reward
             if done:
                 break
-        return obs, total_reward, done, info
+        return obs, total_reward, done, trunk, info


 class GrayScaleObservation(gym.ObservationWrapper):
@@ -203,7 +207,10 @@ def observation(self, observation):
 env = SkipFrame(env, skip=4)
 env = GrayScaleObservation(env)
 env = ResizeObservation(env, shape=84)
-env = FrameStack(env, num_stack=4)
+if gym.__version__ < '0.26':
+    env = FrameStack(env, num_stack=4, new_step_api=True)
+else:
+    env = FrameStack(env, num_stack=4)


 ######################################################################
@@ -283,12 +290,11 @@ def __init__(self, state_dim, action_dim, save_dir):
         self.action_dim = action_dim
         self.save_dir = save_dir

-        self.use_cuda = torch.cuda.is_available()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"

         # Mario's DNN to predict the most optimal action - we implement this in the Learn section
         self.net = MarioNet(self.state_dim, self.action_dim).float()
-        if self.use_cuda:
-            self.net = self.net.to(device="cuda")
+        self.net = self.net.to(device=self.device)

         self.exploration_rate = 1
         self.exploration_rate_decay = 0.99999975
@@ -312,12 +318,8 @@ def act(self, state):

         # EXPLOIT
         else:
-            state = state.__array__()
-            if self.use_cuda:
-                state = torch.tensor(state).cuda()
-            else:
-                state = torch.tensor(state)
-            state = state.unsqueeze(0)
+            state = state[0].__array__() if isinstance(state, tuple) else state.__array__()
+            state = torch.tensor(state, device=self.device).unsqueeze(0)
             action_values = self.net(state, model="online")
             action_idx = torch.argmax(action_values, axis=1).item()

@@ -363,21 +365,16 @@ def cache(self, state, next_state, action, reward, done):
         reward (float),
         done(bool))
         """
-        state = state.__array__()
-        next_state = next_state.__array__()
-
-        if self.use_cuda:
-            state = torch.tensor(state).cuda()
-            next_state = torch.tensor(next_state).cuda()
-            action = torch.tensor([action]).cuda()
-            reward = torch.tensor([reward]).cuda()
-            done = torch.tensor([done]).cuda()
-        else:
-            state = torch.tensor(state)
-            next_state = torch.tensor(next_state)
-            action = torch.tensor([action])
-            reward = torch.tensor([reward])
-            done = torch.tensor([done])
+        def first_if_tuple(x):
+            return x[0] if isinstance(x, tuple) else x
+        state = first_if_tuple(state).__array__()
+        next_state = first_if_tuple(next_state).__array__()
+
+        state = torch.tensor(state, device=self.device)
+        next_state = torch.tensor(next_state, device=self.device)
+        action = torch.tensor([action], device=self.device)
+        reward = torch.tensor([reward], device=self.device)
+        done = torch.tensor([done], device=self.device)

         self.memory.append((state, next_state, action, reward, done,))

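The first_if_tuple helper and the isinstance(state, tuple) checks above are needed because env.reset() under the new API returns an (obs, info) pair rather than a bare observation. A small sketch of that difference, assuming gym with CartPole-v1; the variable names are illustrative:

import gym

env = gym.make("CartPole-v1")
reset_out = env.reset()
# Old API: reset() -> obs; new API (gym >= 0.26): reset() -> (obs, info)
obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out
print(type(obs))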
@@ -753,7 +750,7 @@ def record(self, episode, epsilon, step):
         action = mario.act(state)

         # Agent performs action
-        next_state, reward, done, info = env.step(action)
+        next_state, reward, done, trunc, info = env.step(action)

         # Remember
         mario.cache(state, next_state, action, reward, done)
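The tensor-handling cleanup in this commit follows the standard device-agnostic PyTorch pattern: resolve the device once, then pass device= when creating tensors and .to(device=...) when moving modules, instead of branching on use_cuda. A standalone sketch of the pattern, using an illustrative toy network rather than the tutorial's MarioNet:

import torch
from torch import nn

# Resolve the device once and reuse it everywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tensors are created directly on the target device; no if/else on CUDA availability.
state = torch.tensor([0.0, 1.0, 2.0], device=device)
reward = torch.tensor([1.0], device=device)

# Modules are moved the same way.
net = nn.Linear(3, 2).to(device=device)
print(net(state.unsqueeze(0)), reward)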
