1 file changed: +3 −3 lines changed

@@ -283,7 +283,7 @@ def select_action(state):
             # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
-            return policy_net(state).max(1)[1].view(1, 1)
+            return policy_net(state).max(1).indices.view(1, 1)
     else:
         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
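For context, Tensor.max(dim) returns a named tuple, so .indices is simply the readable accessor for what [1] retrieved positionally. A minimal sketch with a made-up 1×3 batch of Q-values (the toy tensor is illustrative, not part of the tutorial):

import torch

q_values = torch.tensor([[0.1, 0.9, 0.3]])              # hypothetical policy_net output: 1 state x 3 actions
result = q_values.max(1)                                 # named tuple (values, indices)
assert torch.equal(result.indices, q_values.max(1)[1])   # both forms pick the same greedy action
action = result.indices.view(1, 1)                       # tensor([[1]]), the shape select_action returns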
@@ -360,12 +360,12 @@ def optimize_model():

     # Compute V(s_{t+1}) for all next states.
     # Expected values of actions for non_final_next_states are computed based
-    # on the "older" target_net; selecting their best reward with max(1)[0].
+    # on the "older" target_net; selecting their best reward with max(1).values
     # This is merged based on the mask, such that we'll have either the expected
     # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
     with torch.no_grad():
-        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
+        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
     # Compute the expected Q values
     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
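Similarly, .values is the named-tuple counterpart of [0]. A short sketch with toy numbers (the batch size, mask, and Q-values below are made up, not from the tutorial) showing how the masked assignment fills next_state_values and leaves zeros where the episode ended:

import torch

BATCH_SIZE = 4
non_final_mask = torch.tensor([True, False, True, True])   # second transition ended the episode
q_next = torch.tensor([[0.2, 0.7],                         # hypothetical target_net outputs for the
                       [0.5, 0.1],                         # three non-final next states
                       [0.0, 0.9]])

next_state_values = torch.zeros(BATCH_SIZE)
next_state_values[non_final_mask] = q_next.max(1).values   # same result as q_next.max(1)[0]
print(next_state_values)                                   # tensor([0.7000, 0.0000, 0.5000, 0.9000])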