Skip to content

Commit c79d298

Browse files
final review fixes
1 parent beec23c commit c79d298

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

beginner_source/knowledge_distillation_tutorial.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ def forward(self, x):
446446

447447
# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
448448
torch.manual_seed(42)
449-
modified_light_nn = ModifiedLightNNCosine(num_classes=10).to(device)
450-
print("Norm of 1st layer:", torch.norm(modified_light_nn.features[0].weight).item())
449+
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
450+
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())
451451

452452
######################################################################
453453
# Naturally, we need to change the train loop because now the model returns a tuple ``(logits, hidden_representation)``. Using a sample input tensor
@@ -457,7 +457,7 @@ def forward(self, x):
457457
sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Filters: 3, Image size: 32x32
458458

459459
# Pass the input through the student
460-
logits, hidden_representation = modified_light_nn(sample_input)
460+
logits, hidden_representation = modified_nn_light(sample_input)
461461

462462
# Print the shapes of the tensors
463463
print("Student logits shape:", logits.shape) # batch_size x total_classes
@@ -551,8 +551,8 @@ def test_multiple_outputs(model, test_loader, device):
551551
# For now, we can run a simple train-test session.
552552

553553
# Train and test the lightweight network with the combined cross entropy and cosine losses
554-
train_cosine_loss(teacher=modified_nn_deep, student=modified_light_nn, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
555-
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_light_nn, test_loader, device)
554+
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
555+
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
556556

557557
######################################################################
558558
# Intermediate regressor run
@@ -699,15 +699,15 @@ def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, featur
699699

700700
# Initialize a ModifiedLightNNRegressor
701701
torch.manual_seed(42)
702-
modified_light_nn_mse = ModifiedLightNNRegressor(num_classes=10).to(device)
702+
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)
703703

704704
# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
705705
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
706706
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())
707707

708708
# Train and test once again
709-
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_light_nn_mse, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
710-
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_light_nn_mse, test_loader, device)
709+
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
710+
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
711711

712712
######################################################################
713713
# It is expected that the final method will work better than ``CosineLoss`` because now we have allowed a trainable layer between the teacher and the student,

0 commit comments

Comments
 (0)