diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py
index 78dc7e2fc6e..42bea7c3e9e 100644
--- a/intermediate_source/reinforcement_q_learning.py
+++ b/intermediate_source/reinforcement_q_learning.py
@@ -283,7 +283,7 @@ def select_action(state):
             # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
-            return policy_net(state).max(1)[1].view(1, 1)
+            return policy_net(state).max(1).indices.view(1, 1)
     else:
         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
 
@@ -360,12 +360,12 @@ def optimize_model():
 
     # Compute V(s_{t+1}) for all next states.
     # Expected values of actions for non_final_next_states are computed based
-    # on the "older" target_net; selecting their best reward with max(1)[0].
+    # on the "older" target_net; selecting their best reward with max(1).values
     # This is merged based on the mask, such that we'll have either the expected
     # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
     with torch.no_grad():
-        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
+        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
     # Compute the expected Q values
     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
 
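
For reviewers, a minimal standalone sketch (not part of the patch; the tensor values below are made up for illustration) of why the two forms are interchangeable: Tensor.max(dim) returns a named tuple whose values and indices fields refer to the same tensors as positions [0] and [1] of the result, so the change is purely a readability improvement.

    import torch

    # Hypothetical batch of Q-values, 2 states x 2 actions, for illustration only.
    q_values = torch.tensor([[0.1, 0.9],
                             [0.7, 0.3]])

    result = q_values.max(1)                       # named tuple: (values, indices)
    assert torch.equal(result.values, result[0])   # same tensor as the old [0] indexing
    assert torch.equal(result.indices, result[1])  # same tensor as the old [1] indexing

    # Greedy action selection in the spirit of select_action():
    # index of the best Q-value per row, as a column vector.
    best_actions = q_values.max(1).indices.view(-1, 1)
    print(best_actions)                            # tensor([[1], [0]])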