24 | 24 | an action, the environment *transitions* to a new state, and also
25 | 25 | returns a reward that indicates the consequences of the action. In this
26 | 26 | task, rewards are +1 for every incremental timestep and the environment
27 |    | -terminates if the pole falls over too far or the crat mover more then 2.4
   | 27 | +terminates if the pole falls over too far or the cart moves more than 2.4
28 | 28 | units away from center. This means better performing scenarios will run
29 | 29 | for longer duration, accumulating larger return.
30 | 30 |
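The reward structure above is easy to check directly: stepping CartPole with random actions accumulates +1 per surviving timestep, so the episode return equals the episode length. A minimal sketch, assuming gym is installed and the classic step API that returns (observation, reward, done, info):

    import gym

    env = gym.make('CartPole-v0')
    env.reset()
    total_reward, done = 0.0, False
    while not done:
        # Random action; the episode ends when the pole tips too far or the
        # cart drifts more than 2.4 units from center.
        _, reward, done, _ = env.step(env.action_space.sample())
        total_reward += reward      # +1 per timestep survived
    print(total_reward)             # equals the episode length, i.e. the return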
@@ -249,14 +249,15 @@ def forward(self, x):
249 | 249 |                     T.Resize(40, interpolation=Image.CUBIC),
250 | 250 |                     T.ToTensor()])
251 | 251 |
    | 252 | +
252 | 253 | def get_cart_location(screen_width):
253 | 254 |     world_width = env.x_threshold * 2
254 | 255 |     scale = screen_width / world_width
255 | 256 |     return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART
256 | 257 |
257 | 258 | def get_screen():
258 |     | -    # Returned requested by gym is 400x600x3, but is sometimes larger such as
259 |     | -    # as 800x1200x3. Transpose into torch order (CHW).
    | 259 | +    # Returned screen requested by gym is 400x600x3, but is sometimes larger
    | 260 | +    # such as 800x1200x3. Transpose it into torch order (CHW).
260 | 261 |     screen = env.render(mode='rgb_array').transpose((2, 0, 1))
261 | 262 |     # Cart is in the lower half, so strip off the top and bottom of the screen
262 | 263 |     _, screen_height, screen_width = screen.shape
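To make the pixel mapping in get_cart_location concrete, here is a worked example assuming CartPole's default x_threshold of 2.4 and a 600-pixel-wide render buffer (both values are assumptions for illustration):

    screen_width = 600
    world_width = 2.4 * 2                     # 4.8 world units span the screen
    scale = screen_width / world_width        # 125 pixels per world unit
    cart_x = 0.0                              # cart sitting at the center of the track
    print(int(cart_x * scale + screen_width / 2.0))   # -> 300, the middle pixel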
@@ -310,20 +311,18 @@ def get_screen():
310 | 311 | # episode.
311 | 312 | #
312 | 313 |
313 |     | -BATCH_SIZE = 196 #128
    | 314 | +BATCH_SIZE = 128
314 | 315 | GAMMA = 0.999
315 | 316 | EPS_START = 0.9
316 |     | -EPS_END = 0.07
317 |     | -EPS_DECAY = 300
    | 317 | +EPS_END = 0.05
    | 318 | +EPS_DECAY = 200
318 | 319 | TARGET_UPDATE = 10
319 | 320 |
320 | 321 | # Get screen size so that we can initialize layers correctly based on shape
321 |     | -# returned from AI gym. Typical dimentions at this pont are close to 3x40x90
322 |     | -# which is the result of a clamped and down-scaled buffer in get_screen()
    | 322 | +# returned from AI gym. Typical dimensions at this point are close to 3x40x90
    | 323 | +# which is the result of a clamped and down-scaled render buffer in get_screen()
323 | 324 | init_screen = get_screen()
324 | 325 | _, _, screen_height, screen_width = init_screen.shape
325 |     | -#screen_height = init_screen.shape[2]
326 |     | -#print("Screen size w,h:", screen_width, " ", screen_height)
327 | 326 |
328 | 327 | policy_net = DQN(screen_height, screen_width).to(device)
329 | 328 | target_net = DQN(screen_height, screen_width).to(device)
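EPS_START, EPS_END and EPS_DECAY typically drive an epsilon-greedy exploration schedule. A hedged sketch of one common exponential decay (the exact formula is an assumption and is not shown in the hunk above):

    import math

    EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200

    def epsilon(steps_done):
        # Decays from EPS_START toward EPS_END with time constant EPS_DECAY.
        return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

    print(round(epsilon(0), 3), round(epsilon(200), 3), round(epsilon(1000), 3))
    # 0.9 0.363 0.056 -- exploration fades as training progresses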
@@ -452,7 +451,7 @@ def optimize_model():
452 | 451 | # duration improvements.
453 | 452 | #
454 | 453 |
455 |     | -num_episodes = 500
    | 454 | +num_episodes = 50
456 | 455 | for i_episode in range(num_episodes):
457 | 456 |     # Initialize the environment and state
458 | 457 |     env.reset()
@@ -496,14 +495,14 @@ def optimize_model():
496 | 495 | plt.show()
497 | 496 |
498 | 497 | ######################################################################
499 |     | -# Here is the diagram that illustrates the overall resulting flow.
    | 498 | +# Here is the diagram that illustrates the overall resulting data flow.
500 | 499 | #
501 | 500 | # .. figure:: /_static/img/reinforcement_learning_diagram.jpg
502 | 501 | #
503 | 502 | # Actions are chosen either randomly or based on a policy, getting the next
504 |     | -# step sample for the gym environment. We record the results in the
505 |     | -# replay memory and also perform optimization step on every iteration.
    | 503 | +# step sample from the gym environment. We record the results in the
    | 504 | +# replay memory and also run an optimization step on every iteration.
506 | 505 | # Optimization picks a random batch from the replay memory to do training of the
507 |     | -# new policy. "Older" target_net, used in optimization to computed expected
508 |     | -# Q values is updated occasionally to keep it current.
    | 506 | +# new policy. "Older" target_net is also used in optimization to compute the
    | 507 | +# expected Q values; it is updated occasionally to keep it current.
509 | 508 | #
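Two pieces of that flow, random minibatch sampling from the replay memory and the occasional refresh of target_net from policy_net, can be sketched in isolation. This is a toy stand-in (the deque memory, Linear networks, and tensor shapes are made up for illustration), not the tutorial's actual classes:

    import random
    from collections import deque
    import torch
    import torch.nn as nn

    BATCH_SIZE, TARGET_UPDATE = 128, 10
    policy_net = nn.Linear(4, 2)                 # stand-ins for the two DQN networks
    target_net = nn.Linear(4, 2)

    memory = deque(maxlen=1000)                  # toy replay memory of (s, a, r, s') tuples
    for step in range(500):
        memory.append((torch.randn(4), step % 2, 1.0, torch.randn(4)))

    for i_episode in range(30):
        batch = random.sample(memory, BATCH_SIZE)        # optimization draws a random batch
        next_states = torch.stack([s_next for s, a, r, s_next in batch])
        with torch.no_grad():
            expected_q = target_net(next_states)         # "older" net gives expected Q values
        # ... compute Q(s, a) with policy_net, form the loss, and step the optimizer ...
        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())   # occasional refresh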