diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index a5ba6ada8a1..ad38aed66fc 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -63,7 +63,7 @@ import numpy as np import matplotlib import matplotlib.pyplot as plt -from collections import namedtuple +from collections import namedtuple, deque from itertools import count from PIL import Image @@ -115,16 +115,11 @@ class ReplayMemory(object): def __init__(self, capacity): - self.capacity = capacity - self.memory = [] - self.position = 0 + self.memory = deque([],maxlen=capacity) def push(self, *args): - """Saves a transition.""" - if len(self.memory) < self.capacity: - self.memory.append(None) - self.memory[self.position] = Transition(*args) - self.position = (self.position + 1) % self.capacity + """Save a transition""" + self.memory.append(Transition(*args)) def sample(self, batch_size): return random.sample(self.memory, batch_size)