
Commit b7569a3

Author: Michael Antonov

Commit message: Fixed comment typos and training values.

1 parent 5f1dcfe

File tree

1 file changed: +15 / -16 lines


intermediate_source/reinforcement_q_learning.py

Lines changed: 15 additions & 16 deletions
@@ -24,7 +24,7 @@
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
 task, rewards are +1 for every incremental timestep and the environment
-terminates if the pole falls over too far or the crat mover more then 2.4
+terminates if the pole falls over too far or the cart moves more then 2.4
 units away from center. This means better performing scenarios will run
 for longer duration, accumulating larger return.
 
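For context on the reward wording being fixed above: under the classic gym API, the +1-per-timestep return accumulates until the episode terminates. A minimal sketch, not part of the tutorial file (the 4-tuple return of step() is an assumption about the gym version in use):

    import gym

    env = gym.make('CartPole-v0')
    env.reset()
    total_return, done = 0, False
    while not done:
        # Random action, just to illustrate the reward structure.
        _, reward, done, _ = env.step(env.action_space.sample())
        total_return += reward  # +1 for every timestep the pole stays up
    print(total_return)  # longer-surviving episodes accumulate a larger return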
@@ -249,14 +249,15 @@ def forward(self, x):
                     T.Resize(40, interpolation=Image.CUBIC),
                     T.ToTensor()])
 
+
 def get_cart_location(screen_width):
     world_width = env.x_threshold * 2
     scale = screen_width / world_width
     return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
 
 def get_screen():
-    # Returned requested by gym is 400x600x3, but is sometimes larger such as
-    # as 800x1200x3. Transpose into torch order (CHW).
+    # Returned screen requested by gym is 400x600x3, but is sometimes larger
+    # such as 800x1200x3. Transpose it into torch order (CHW).
     screen = env.render(mode='rgb_array').transpose((2, 0, 1))
     # Cart is in the lower half, so strip off the top and bottom of the screen
     _, screen_height, screen_width = screen.shape
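A side note on the comment reworded in this hunk: the (2, 0, 1) transpose only reorders the rendered H x W x C array into torch's C x H x W layout. A tiny stand-alone sketch (the 400x600 frame is a stand-in, not taken from the file):

    import numpy as np

    frame = np.zeros((400, 600, 3), dtype=np.uint8)  # stand-in for env.render(mode='rgb_array')
    chw = frame.transpose((2, 0, 1))                 # reorder HWC -> CHW for torch
    assert chw.shape == (3, 400, 600)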
@@ -310,20 +311,18 @@ def get_screen():
 # episode.
 #
 
-BATCH_SIZE = 196 #128
+BATCH_SIZE = 128
 GAMMA = 0.999
 EPS_START = 0.9
-EPS_END = 0.07
-EPS_DECAY = 300
+EPS_END = 0.05
+EPS_DECAY = 200
 TARGET_UPDATE = 10
 
 # Get screen size so that we can initialize layers correctly based on shape
-# returned from AI gym. Typical dimentions at this pont are close to 3x40x90
-# which is the result of a clamped and down-scaled buffer in get_screen()
+# returned from AI gym. Typical dimensions at this point are close to 3x40x90
+# which is the result of a clamped and down-scaled render buffer in get_screen()
 init_screen = get_screen()
 _, _, screen_height, screen_width = init_screen.shape
-#screen_height = init_screen.shape[2]
-#print("Screen size w,h:", screen_width, " ", screen_height)
 
 policy_net = DQN(screen_height, screen_width).to(device)
 target_net = DQN(screen_height, screen_width).to(device)
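For readers wondering how the retuned EPS_* values are consumed: in this tutorial the exploration probability typically decays exponentially from EPS_START toward EPS_END with time constant EPS_DECAY. A self-contained sketch of that formula using the new values from this commit (check select_action in the file for the exact code; steps_done is assumed to be a global action counter):

    import math
    import random

    EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200
    steps_done = 0

    def eps_threshold(step):
        # Exploration probability after `step` action selections.
        return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * step / EPS_DECAY)

    # Explore with probability eps_threshold(steps_done), otherwise exploit the policy network.
    explore = random.random() < eps_threshold(steps_done)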
@@ -452,7 +451,7 @@ def optimize_model():
 # duration improvements.
 #
 
-num_episodes = 500
+num_episodes = 50
 for i_episode in range(num_episodes):
     # Initialize the environment and state
     env.reset()
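The lines that follow env.reset() in the file (not shown in this hunk) build the DQN input as the difference of two consecutive rendered screens; roughly this shape (a sketch based on the surrounding tutorial code, not verbatim file contents):

    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen  # screen difference fed to the network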
@@ -496,14 +495,14 @@ def optimize_model():
 plt.show()
 
 ######################################################################
-# Here is the diagram that illustrates the overall resulting flow.
+# Here is the diagram that illustrates the overall resulting data flow.
 #
 # .. figure:: /_static/img/reinforcement_learning_diagram.jpg
 #
 # Actions are chosen either randomly or based on a policy, getting the next
-# step sample for the gym environment. We record the results in the
-# replay memory and also perform optimization step on every iteration.
+# step sample from the gym environment. We record the results in the
+# replay memory and also run optimization step on every iteration.
 # Optimization picks a random batch from the replay memory to do training of the
-# new policy. "Older" target_net, used in optimization to computed expected
-# Q values is updated occasionally to keep it current.
+# new policy. "Older" target_net is also used in optimization to compute the
+# expected Q values; it is updated occasionally to keep it current.
 #
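The "updated occasionally" part of the reworded comment corresponds to the TARGET_UPDATE hyperparameter set earlier in this diff; in the training loop the sync looks roughly like this (sketch, not verbatim file contents):

    # Copy the freshly trained policy weights into the frozen target network
    # every TARGET_UPDATE episodes, keeping the Q-value targets reasonably current.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())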
