Commit 5c632a0

Update reinforcement_q_learning.py - use named tuples rather than indices (#2689)
1 parent: 51a3f60

File tree

1 file changed (+3, -3 lines)

intermediate_source/reinforcement_q_learning.py

Lines changed: 3 additions & 3 deletions
@@ -283,7 +283,7 @@ def select_action(state):
             # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
-            return policy_net(state).max(1)[1].view(1, 1)
+            return policy_net(state).max(1).indices.view(1, 1)
     else:
         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
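
For context on this change: in PyTorch, Tensor.max(dim) returns a named tuple (values, indices), so .indices is the same tensor that the old positional [1] selected; the edit is purely a readability improvement. A minimal sketch of the equivalence (q_values below is a hypothetical stand-in for the policy network's output, not part of the tutorial):

import torch

# Hypothetical Q-values for one state and two actions.
q_values = torch.tensor([[0.3, 0.7]])

result = q_values.max(1)                       # named tuple: (values=..., indices=...)
assert torch.equal(result[1], result.indices)  # positional [1] and .indices are the same tensor
greedy_action = result.indices.view(1, 1)      # shape (1, 1), mirroring select_action's return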

@@ -360,12 +360,12 @@ def optimize_model():
 
     # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
-    # on the "older" target_net; selecting their best reward with max(1)[0].
+    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
-        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
+        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
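
Likewise, .values names the tensor that the old positional [0] selected from the max result. A small sketch of the masked-assignment pattern under assumed placeholder data (the mask and Q-value tensors below are illustrative, not the tutorial's real batch):

import torch

BATCH_SIZE = 4
# Assume states 0, 1, and 3 are non-final; state 2 is terminal.
non_final_mask = torch.tensor([True, True, False, True])
# Placeholder target-net outputs for the three non-final next states (two actions each).
next_q = torch.tensor([[0.1, 0.5],
                       [0.9, 0.2],
                       [0.4, 0.6]])

next_state_values = torch.zeros(BATCH_SIZE)
with torch.no_grad():
    # .values is the named-tuple field equivalent to indexing the max result with [0].
    next_state_values[non_final_mask] = next_q.max(1).values
print(next_state_values)  # tensor([0.5000, 0.9000, 0.0000, 0.6000])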
