
Commit 8dfdda4

Juphex and holly1238 authored
fixing loss (#1420)
Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent f1fd16f commit 8dfdda4

File tree: 1 file changed (+2 −1 lines)


intermediate_source/reinforcement_q_learning.py

Lines changed: 2 additions & 1 deletion

@@ -426,7 +426,8 @@ def optimize_model():
     expected_state_action_values = (next_state_values * GAMMA) + reward_batch

     # Compute Huber loss
-    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
+    criterion = nn.SmoothL1Loss()
+    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

     # Optimize the model
     optimizer.zero_grad()
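For context on why this change is safe: `nn.SmoothL1Loss()` is the module wrapper around `torch.nn.functional.smooth_l1_loss`, so with default arguments the two forms compute the same Huber-style loss. A minimal sketch, using made-up tensors whose shapes are assumed to mirror the tutorial's Q-value batches:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the tutorial's batched Q-values:
# state_action_values has shape (batch, 1), the expected values shape (batch,).
state_action_values = torch.randn(4, 1)
expected_state_action_values = torch.randn(4)

# Module form introduced by this commit.
criterion = nn.SmoothL1Loss()
loss_module = criterion(state_action_values,
                        expected_state_action_values.unsqueeze(1))

# Functional form it replaced.
loss_functional = F.smooth_l1_loss(state_action_values,
                                   expected_state_action_values.unsqueeze(1))

# Both reduce to the same scalar mean loss.
assert torch.allclose(loss_module, loss_functional)
```

Using the module form lets the criterion be constructed once (e.g. with a non-default `beta` or `reduction`) and reused, which is the usual idiom in PyTorch training loops.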
