@@ -446,8 +446,8 @@ def forward(self, x):

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
-modified_light_nn = ModifiedLightNNCosine(num_classes=10).to(device)
-print("Norm of 1st layer:", torch.norm(modified_light_nn.features[0].weight).item())
+modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
+print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

######################################################################
# Naturally, we need to change the train loop because now the model returns a tuple ``(logits, hidden_representation)``. Using a sample input tensor
@@ -457,7 +457,7 @@ def forward(self, x):
sample_input = torch.randn(128, 3, 32, 32).to(device)  # Batch size: 128, Channels: 3, Image size: 32x32

# Pass the input through the student
-logits, hidden_representation = modified_light_nn(sample_input)
+logits, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape)  # batch_size x total_classes
@@ -551,8 +551,8 @@ def test_multiple_outputs(model, test_loader, device):
# For now, we can run a simple train-test session.

# Train and test the lightweight network with the combined cosine and cross-entropy loss
-train_cosine_loss(teacher=modified_nn_deep, student=modified_light_nn, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
-test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_light_nn, test_loader, device)
+train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
+test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
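The helper ``test_multiple_outputs``, named in the hunk header above, is also defined outside this diff. Since the modified students return a tuple, a plausible sketch is an ordinary accuracy loop that discards the hidden representation; the signature is confirmed by the calls here, the body is an assumption.

import torch

def test_multiple_outputs(model, test_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The modified networks return (logits, hidden_representation);
            # only the logits matter for accuracy.
            logits, _ = model(inputs)
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy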

######################################################################
# Intermediate regressor run
@@ -699,15 +699,15 @@ def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, featur

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
-modified_light_nn_mse = ModifiedLightNNRegressor(num_classes=10).to(device)
+modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch, of course; we just load its weights from the trained instance.
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
-train_mse_loss(teacher=modified_nn_deep_reg, student=modified_light_nn_mse, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
-test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_light_nn_mse, test_loader, device)
+train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
+test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
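For context, ``train_mse_loss`` (its signature is partially visible in the hunk header above and confirmed by the call here) matches teacher and student feature maps with an MSE loss; per the closing comment below, the student's ``ModifiedLightNNRegressor`` adds a trainable regressor layer that projects its feature map to the teacher's shape. A minimal sketch under those assumptions:

import torch
import torch.nn as nn
import torch.optim as optim

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate,
                   feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Assumes the modified teacher returns (logits, feature_map).
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)
            # Assumes the student's regressor layer has already projected its
            # feature map to the teacher's shape, so the MSE is well defined.
            student_logits, regressor_feature_map = student(inputs)

            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
            label_loss = ce_loss(student_logits, labels)

            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")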

######################################################################
# It is expected that the final method will work better than ``CosineLoss`` because now we have allowed a trainable layer between the teacher and the student,