Commit a0e5e66
Qznan and svekars authored
Fix the Calculation of the KL divergence loss. (#2785)
Fix the calculation of the KL divergence loss. Refer to torch.nn.KLDivLoss in PyTorch.
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 7e83c23 commit a0e5e66

File tree

1 file changed: +1 −1 lines changed


beginner_source/knowledge_distillation_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def train_knowledge_distillation(teacher, student, train_loader, epochs, learnin
         soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

         # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
-        soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)
+        soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

         # Calculate the true label loss
         label_loss = ce_loss(student_logits, labels)
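
For context, the corrected line computes the forward KL divergence between the softened teacher distribution and the softened student distribution, averaged over the batch and scaled by T**2. Below is a minimal standalone sketch (not part of the commit; the batch size, class count, and temperature are illustrative assumptions) checking that the manual expression matches torch.nn.KLDivLoss with reduction="batchmean":

import torch
import torch.nn as nn

T = 2.0  # distillation temperature (illustrative value)
teacher_logits = torch.randn(8, 10)  # random stand-ins for real logits
student_logits = torch.randn(8, 10)

# Soften the distributions, as in the tutorial.
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

# Corrected manual computation from this commit: sum of p * (log p - log q),
# averaged over the batch and scaled by T**2.
manual = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

# Equivalent library call: KLDivLoss expects log-probabilities as input
# and probabilities as target.
kl_div = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T**2)

print(torch.allclose(manual, kl_div))  # True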
