Commit 6e0fd0a

neuralninja27 authored, with Vincent Moens (vmoens) and Svetlana Karslioglu
Update mario_rl_tutorial.py (#2381)
Update mario_rl_tutorial.py

Fixes #1620

Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 2284ab2 commit 6e0fd0a

File tree: 1 file changed (+7, -4)


intermediate_source/mario_rl_tutorial.py

Lines changed: 7 additions & 4 deletions
@@ -53,6 +53,8 @@
 # Super Mario environment for OpenAI Gym
 import gym_super_mario_bros
 
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
 
 ######################################################################
 # RL Definitions
@@ -348,7 +350,7 @@ def act(self, state):
 class Mario(Mario):  # subclassing for continuity
     def __init__(self, state_dim, action_dim, save_dir):
         super().__init__(state_dim, action_dim, save_dir)
-        self.memory = deque(maxlen=100000)
+        self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
         self.batch_size = 32
 
     def cache(self, state, next_state, action, reward, done):
@@ -373,14 +375,15 @@ def first_if_tuple(x):
         reward = torch.tensor([reward], device=self.device)
         done = torch.tensor([done], device=self.device)
 
-        self.memory.append((state, next_state, action, reward, done,))
+        # self.memory.append((state, next_state, action, reward, done,))
+        self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))
 
     def recall(self):
         """
         Retrieve a batch of experiences from memory
         """
-        batch = random.sample(self.memory, self.batch_size)
-        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
+        batch = self.memory.sample(self.batch_size)
+        state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
 
         return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
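For readers following along, here is the new replay-buffer pattern in isolation: a minimal, self-contained sketch of the add/sample round trip this diff introduces. The buffer construction and the add/sample/get calls mirror the changed lines above; the tensor shapes and values are illustrative placeholders assumed for the sketch, not taken from the tutorial.

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

# Same capacity as the old deque(maxlen=100000); LazyMemmapStorage
# memory-maps stored tensors to disk instead of holding them in RAM.
buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))

# Each transition is a keyed TensorDict; batch_size=[] marks a single item.
# Shapes here are illustrative placeholders.
for _ in range(64):
    buffer.add(
        TensorDict(
            {
                "state": torch.rand(4, 84, 84),
                "next_state": torch.rand(4, 84, 84),
                "action": torch.tensor([1]),
                "reward": torch.tensor([0.5]),
                "done": torch.tensor([False]),
            },
            batch_size=[],
        )
    )

# sample() returns a single TensorDict whose leading dimension is the batch
# size, so the manual zip(*batch) / torch.stack step from the deque version
# is no longer needed.
batch = buffer.sample(32)
state = batch.get("state")  # shape: torch.Size([32, 4, 84, 84])

The design gain is twofold: LazyMemmapStorage keeps the 100,000-transition buffer on disk rather than in RAM, and sampling returns pre-stacked, keyed tensors, so recall() no longer has to zip and stack Python tuples.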
