
Commit 0fa8074

Merge pull request #389 from mike7ant/master
Fixed reinforcement learning to run with any screen size; added diagram
2 parents 4779db8 + b7569a3 commit 0fa8074

File tree: 2 files changed, +65 −22 lines changed

_static/img/reinforcement_learning_diagram.jpg — new image, 23.2 KB

intermediate_source/reinforcement_q_learning.py

Lines changed: 65 additions & 22 deletions
@@ -23,7 +23,10 @@
 As the agent observes the current state of the environment and chooses
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
-task, the environment terminates if the pole falls over too far.
+task, rewards are +1 for every incremental timestep and the environment
+terminates if the pole falls over too far or the cart moves more than 2.4
+units away from center. This means better-performing scenarios will run
+for a longer duration, accumulating a larger return.
 
 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
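
A minimal sketch (not part of the diff) of the reward structure this paragraph describes, assuming the older `gym` step API that this tutorial uses:

    import gym

    env = gym.make('CartPole-v0')
    env.reset()
    total_return = 0.0
    done = False
    while not done:
        # Random action just for illustration: 0 pushes the cart left, 1 pushes it right.
        _, reward, done, _ = env.step(env.action_space.sample())
        total_return += reward  # reward is +1 for every timestep the episode survives
    print(total_return)         # longer-lived episodes accumulate a larger return
    env.close()
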
@@ -97,7 +100,9 @@
 # For this, we're going to need two classes:
 #
 # - ``Transition`` - a named tuple representing a single transition in
-#   our environment
+#   our environment. It essentially maps (state, action) pairs
+#   to their (next_state, reward) result, with the state being the
+#   screen difference image as described later on.
 # - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
 #   transitions observed recently. It also implements a ``.sample()``
 #   method for selecting a random batch of transitions for training.
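
For reference, the definitions these two bullet points describe look roughly like the following (a sketch only; the tutorial's actual classes sit just outside this hunk):

    import random
    from collections import namedtuple

    Transition = namedtuple('Transition',
                            ('state', 'action', 'next_state', 'reward'))

    class ReplayMemory(object):
        """Cyclic buffer of bounded size holding recently observed transitions."""

        def __init__(self, capacity):
            self.capacity = capacity
            self.memory = []
            self.position = 0

        def push(self, *args):
            # Save a transition, overwriting the oldest entry once the buffer is full.
            if len(self.memory) < self.capacity:
                self.memory.append(None)
            self.memory[self.position] = Transition(*args)
            self.position = (self.position + 1) % self.capacity

        def sample(self, batch_size):
            # Random batch for training.
            return random.sample(self.memory, batch_size)

        def __len__(self):
            return len(self.memory)
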
@@ -197,22 +202,32 @@ def __len__(self):
 # difference between the current and previous screen patches. It has two
 # outputs, representing :math:`Q(s, \mathrm{left})` and
 # :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
-# network). In effect, the network is trying to predict the *quality* of
+# network). In effect, the network is trying to predict the *expected return* of
 # taking each action given the current input.
 #
 
 class DQN(nn.Module):
 
-    def __init__(self):
+    def __init__(self, h, w):
         super(DQN, self).__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
         self.bn1 = nn.BatchNorm2d(16)
         self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
         self.bn2 = nn.BatchNorm2d(32)
         self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
         self.bn3 = nn.BatchNorm2d(32)
-        self.head = nn.Linear(448, 2)
 
+        # Number of Linear input connections depends on output of conv2d layers
+        # and therefore the input image size, so compute it.
+        def conv2d_size_out(size, kernel_size=5, stride=2):
+            return (size - (kernel_size - 1) - 1) // stride + 1
+        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
+        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
+        linear_input_size = convw * convh * 32
+        self.head = nn.Linear(linear_input_size, 2)  # 448 or 512
+
+    # Called with either one element to determine next action, or a batch
+    # during optimization. Returns tensor([[left0exp,right0exp]...]).
     def forward(self, x):
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
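
A quick worked check (not part of the diff) of the size computation above, using the roughly 40x90 screens that `get_screen()` produces later in this change:

    def conv2d_size_out(size, kernel_size=5, stride=2):
        return (size - (kernel_size - 1) - 1) // stride + 1

    h, w = 40, 90               # typical down-scaled screen size
    convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))  # 90 -> 43 -> 20 -> 8
    convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))  # 40 -> 18 ->  7 -> 2
    print(convw * convh * 32)   # 512, matching the "448 or 512" comment above
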
@@ -234,23 +249,21 @@ def forward(self, x):
                     T.Resize(40, interpolation=Image.CUBIC),
                     T.ToTensor()])
 
-# This is based on the code from gym.
-screen_width = 600
 
-
-def get_cart_location():
+def get_cart_location(screen_width):
     world_width = env.x_threshold * 2
     scale = screen_width / world_width
     return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
 
-
 def get_screen():
-    screen = env.render(mode='rgb_array').transpose(
-        (2, 0, 1))  # transpose into torch order (CHW)
-    # Strip off the top and bottom of the screen
-    screen = screen[:, 160:320]
-    view_width = 320
-    cart_location = get_cart_location()
+    # Returned screen requested by gym is 400x600x3, but is sometimes larger
+    # such as 800x1200x3. Transpose it into torch order (CHW).
+    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+    # Cart is in the lower half, so strip off the top and bottom of the screen
+    _, screen_height, screen_width = screen.shape
+    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
+    view_width = int(screen_width * 0.6)
+    cart_location = get_cart_location(screen_width)
     if cart_location < view_width // 2:
         slice_range = slice(view_width)
     elif cart_location > (screen_width - view_width // 2):
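
Quick arithmetic check (not part of the diff) of the cropping above for gym's default 400x600 render:

    screen_height, screen_width = 400, 600
    top, bottom = int(screen_height * 0.4), int(screen_height * 0.8)  # 160, 320
    view_width = int(screen_width * 0.6)                              # 360
    # For the default render size the new code keeps rows 160:320, the same band
    # the old hard-coded slice used, while view_width now scales with the actual
    # render width instead of being fixed at 320.
    print(top, bottom, view_width)
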
@@ -305,8 +318,14 @@ def get_screen():
 EPS_DECAY = 200
 TARGET_UPDATE = 10
 
-policy_net = DQN().to(device)
-target_net = DQN().to(device)
+# Get screen size so that we can initialize layers correctly based on shape
+# returned from AI gym. Typical dimensions at this point are close to 3x40x90
+# which is the result of a clamped and down-scaled render buffer in get_screen()
+init_screen = get_screen()
+_, _, screen_height, screen_width = init_screen.shape
+
+policy_net = DQN(screen_height, screen_width).to(device)
+target_net = DQN(screen_height, screen_width).to(device)
 target_net.load_state_dict(policy_net.state_dict())
 target_net.eval()
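
A small sanity check (not part of the diff) using the tutorial's own `get_screen()`: the exact numbers depend on your render size, but the shape the networks are built for can be inspected directly.

    init_screen = get_screen()
    print(init_screen.shape)  # e.g. torch.Size([1, 3, 40, 90]) for a 400x600 render
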

@@ -325,6 +344,9 @@ def select_action(state):
     steps_done += 1
     if sample > eps_threshold:
         with torch.no_grad():
+            # t.max(1) will return the largest column value of each row.
+            # The second column of the max result is the index of where the max
+            # element was found, so we pick the action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
     else:
         return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
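
A tiny illustration (not part of the diff) of what `.max(1)` returns on a batch of Q-values:

    import torch

    q_values = torch.tensor([[0.2, 0.7]])    # [[Q(s, left), Q(s, right)]]
    values, indices = q_values.max(1)
    print(values)   # tensor([0.7000])  -> the larger expected return
    print(indices)  # tensor([1])       -> index of that action (right)
    print(q_values.max(1)[1].view(1, 1))  # tensor([[1]]), the shape select_action returns
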
@@ -376,10 +398,12 @@ def optimize_model():
         return
     transitions = memory.sample(BATCH_SIZE)
     # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
-    # detailed explanation).
+    # detailed explanation). This converts batch-array of Transitions
+    # to Transition of batch-arrays.
     batch = Transition(*zip(*transitions))
 
     # Compute a mask of non-final states and concatenate the batch elements
+    # (a final state would've been the one after which simulation ended)
     non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                             batch.next_state)), device=device, dtype=torch.uint8)
     non_final_next_states = torch.cat([s for s in batch.next_state
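
A tiny illustration (not part of the diff) of the `zip(*...)` transpose mentioned above:

    from collections import namedtuple

    Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
    transitions = [Transition('s0', 'a0', 's1', 1.0),
                   Transition('s1', 'a1', 's2', 1.0)]
    batch = Transition(*zip(*transitions))
    print(batch.state)   # ('s0', 's1')
    print(batch.reward)  # (1.0, 1.0)  -- one tuple per field, ready for torch.cat
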
@@ -389,10 +413,15 @@ def optimize_model():
     reward_batch = torch.cat(batch.reward)
 
     # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
-    # columns of actions taken
+    # columns of actions taken. These are the actions which would've been taken
+    # for each batch state according to policy_net
     state_action_values = policy_net(state_batch).gather(1, action_batch)
 
     # Compute V(s_{t+1}) for all next states.
+    # Expected values of actions for non_final_next_states are computed based
+    # on the "older" target_net; selecting their best reward with max(1)[0].
+    # This is merged based on the mask, such that we'll have either the expected
+    # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
     next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
     # Compute the expected Q values
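
A sketch (not part of the diff) of the standard DQN target that these values feed into; `GAMMA`, `F`, and the batch tensors are the names used by the surrounding tutorial code:

    # expected Q(s_t, a) = r + GAMMA * max_a' Q_target(s_{t+1}, a'), with 0 for final states
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
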
@@ -418,7 +447,8 @@ def optimize_model():
 # fails), we restart the loop.
 #
 # Below, `num_episodes` is set small. You should download
-# the notebook and run a lot more episodes.
+# the notebook and run a lot more episodes, such as 300+ for meaningful
+# duration improvements.
 #
 
 num_episodes = 50
@@ -454,7 +484,7 @@ def optimize_model():
             episode_durations.append(t + 1)
             plot_durations()
             break
-    # Update the target network
+    # Update the target network, copying all weights and biases in DQN
     if i_episode % TARGET_UPDATE == 0:
         target_net.load_state_dict(policy_net.state_dict())

@@ -463,3 +493,16 @@ def optimize_model():
 env.close()
 plt.ioff()
 plt.show()
+
+######################################################################
+# Here is a diagram that illustrates the overall resulting data flow.
+#
+# .. figure:: /_static/img/reinforcement_learning_diagram.jpg
+#
+# Actions are chosen either randomly or based on a policy, getting the next
+# step sample from the gym environment. We record the results in the
+# replay memory and also run the optimization step on every iteration.
+# Optimization picks a random batch from the replay memory for training the
+# new policy. The "older" target_net is also used in optimization to compute
+# the expected Q values; it is updated occasionally to keep it current.
+#
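
A condensed sketch (not part of the diff) of the loop the diagram describes; `select_action`, `memory`, `optimize_model`, `get_screen`, `policy_net`, `target_net`, `TARGET_UPDATE`, and `device` are the tutorial's own names, and the older 4-tuple `gym` step API is assumed.

    from itertools import count  # also imported by the tutorial

    for i_episode in range(num_episodes):
        env.reset()
        last_screen = get_screen()
        current_screen = get_screen()
        state = current_screen - last_screen                # screen-difference state
        for t in count():
            action = select_action(state)                   # eps-greedy: random or policy_net
            _, reward, done, _ = env.step(action.item())
            reward = torch.tensor([reward], device=device)

            last_screen = current_screen
            current_screen = get_screen()
            next_state = None if done else current_screen - last_screen

            memory.push(state, action, next_state, reward)  # record in replay memory
            state = next_state
            optimize_model()                                # train on a random replay batch
            if done:
                break
        if i_episode % TARGET_UPDATE == 0:                  # occasionally refresh target_net
            target_net.load_state_dict(policy_net.state_dict())
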
