
Commit cafd91e

Apply suggestions from code review
Editorial clean up
1 parent c79d298 commit cafd91e

File tree

1 file changed: 10 additions, 9 deletions

beginner_source/knowledge_distillation_tutorial.py

Lines changed: 10 additions & 9 deletions
@@ -29,7 +29,7 @@
 #
 # * 1 GPU, 4GB of memory
 # * PyTorch v2.0 or later
-# * CIFAR-10 dataset (downloaded by the script and saved it in a directory called ``/data``)
+# * CIFAR-10 dataset (downloaded by the script and saved in a directory called ``/data``)
 
 import torch
 import torch.nn as nn
@@ -156,7 +156,7 @@ def forward(self, x):
 # One function is called ``train`` and takes the following arguments:
 #
 # - ``model``: A model instance to train (update its weights) via this function.
-# - ``train_loader``: we defined our ``train_loader`` above, and its job is to feed the data into the model.
+# - ``train_loader``: We defined our ``train_loader`` above, and its job is to feed the data into the model.
 # - ``epochs``: How many times we loop over the dataset.
 # - ``learning_rate``: The learning rate determines how large our steps towards convergence should be. Too large or too small steps can be detrimental.
 # - ``device``: Determines the device to run the workload on. Can be either CPU or GPU depending on availability.
@@ -166,7 +166,7 @@ def forward(self, x):
 # .. figure:: /../_static/img/knowledge_distillation/ce_only.png
 # :align: center
 #
-# Train both networks with Cross-Entropy. The student will be used as a baseline
+# Train both networks with Cross-Entropy. The student will be used as a baseline:
 #
 
 def train(model, train_loader, epochs, learning_rate, device):
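
For context beyond this hunk, the ``train`` signature shown above belongs to a standard cross-entropy training loop; the body is not part of this diff. A minimal sketch of such a loop, assuming ``torch.optim.Adam`` as the optimizer (the exact body in the tutorial file may differ in details):

import torch
import torch.nn as nn
import torch.optim as optim

def train(model, train_loader, epochs, learning_rate, device):
    # Plain cross-entropy training: one optimizer step per batch.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)               # shape: (batch_size, num_classes)
            loss = criterion(logits, labels)
            loss.backward()                      # gradients w.r.t. the model being trained
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")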
@@ -250,14 +250,14 @@ def test(model, test_loader, device):
 print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
 
 ######################################################################
-# Print the total number of parameters in each model.
+# Print the total number of parameters in each model:
 total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
 print(f"DeepNN parameters: {total_params_deep}")
 total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
 print(f"LightNN parameters: {total_params_light}")
 
 ######################################################################
-# Train and test the lightweight network with cross entropy loss
+# Train and test the lightweight network with cross entropy loss:
 train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
 test_accuracy_light_ce = test(nn_light, test_loader, device)
 
@@ -295,7 +295,7 @@ def test(model, test_loader, device):
 # .. figure:: /../_static/img/knowledge_distillation/distillation_output_loss.png
 # :align: center
 #
-# Distillation loss is calculated from the logits of the networks. It only returns gradients to the student
+# Distillation loss is calculated from the logits of the networks. It only returns gradients to the student:
 #
 
 def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
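
The caption above states the key property of the distillation objective: the soft targets come from the teacher's logits, and only the student receives gradients. A minimal sketch of one batch of that objective, reusing the parameter names from the signature above (``T``, ``soft_target_loss_weight``, ``ce_loss_weight``); it follows Hinton et al. (2015) and may differ in minor details from the tutorial's exact implementation:

import torch
import torch.nn.functional as F

def distillation_loss(teacher, student, inputs, labels, T, soft_target_loss_weight, ce_loss_weight):
    # Teacher runs in inference mode only: no gradients ever flow back to it.
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    student_logits = student(inputs)

    # Soften both distributions with temperature T and compare them with KL divergence.
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    soft_targets_loss = F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * (T ** 2)

    # Ordinary cross-entropy against the ground-truth labels.
    label_loss = F.cross_entropy(student_logits, labels)

    # Weighted sum of the two terms; only student parameters appear in the graph.
    return soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss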
@@ -479,7 +479,7 @@ def forward(self, x):
 # .. figure:: /../_static/img/knowledge_distillation/cosine_loss_distillation.png
 # :align: center
 #
-# In Cosine Loss minimization we want to maximize the cosine similarity of the two representations by returning gradients to the student
+# In Cosine Loss minimization, we want to maximize the cosine similarity of the two representations by returning gradients to the student:
 #
 
 def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
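
The caption above describes the hidden-representation variant: both networks also expose an intermediate representation, and a cosine term pulls the student's representation toward the teacher's. A minimal sketch of one batch, assuming the modified networks return ``(logits, hidden_representation)`` as introduced earlier in the tutorial; names other than those in the signature above are illustrative:

import torch
import torch.nn as nn

cosine_loss = nn.CosineEmbeddingLoss()
ce_loss = nn.CrossEntropyLoss()

def cosine_distillation_loss(teacher, student, inputs, labels, hidden_rep_loss_weight, ce_loss_weight):
    # Teacher provides target representations only; no gradients flow to it.
    with torch.no_grad():
        _, teacher_hidden = teacher(inputs)

    student_logits, student_hidden = student(inputs)

    # A target of +1 for every pair tells CosineEmbeddingLoss to maximize cosine similarity.
    target = torch.ones(inputs.size(0), device=inputs.device)
    hidden_rep_loss = cosine_loss(student_hidden, teacher_hidden, target)

    label_loss = ce_loss(student_logits, labels)
    return hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss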
@@ -589,7 +589,7 @@ def test_multiple_outputs(model, test_loader, device):
 # .. figure:: /../_static/img/knowledge_distillation/fitnets_knowledge_distill.png
 # :align: center
 #
-# The trainable layer matches the shapes of the intermediate tensors and Mean Squared Error ``(MSE)`` is properly defined
+# The trainable layer matches the shapes of the intermediate tensors and Mean Squared Error (MSE) is properly defined:
 #
 
 class ModifiedDeepNNRegressor(nn.Module):
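
The caption above is the FitNets-style idea: an extra trainable layer on the student reshapes its intermediate feature map to the teacher's shape, so that MSE between the two tensors is well defined. A minimal sketch of one batch, assuming the modified networks return ``(logits, feature_map)`` and that the student already contains the small regressor layer; the loss-weight names here are illustrative, not taken from the diff:

import torch
import torch.nn as nn

mse_loss = nn.MSELoss()
ce_loss = nn.CrossEntropyLoss()

def regressor_mse_loss(teacher, student, inputs, labels, feature_map_weight, ce_loss_weight):
    # Teacher feature maps are fixed targets; no gradients flow back to the teacher.
    with torch.no_grad():
        _, teacher_feature_map = teacher(inputs)

    # The student's extra regressor (e.g. a small conv layer) has already matched
    # its feature-map shape to the teacher's, so MSE is well defined.
    student_logits, regressor_feature_map = student(inputs)

    hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
    label_loss = ce_loss(student_logits, labels)
    return feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss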
@@ -721,7 +721,7 @@ def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, featur
 print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")
 
 ######################################################################
-# Conclusions
+# Conclusion
 # --------------------------------------------
 # None of the methods above increases the number of parameters for the network or inference time,
 # so the performance increase comes at the little cost of calculating gradients during training.
@@ -732,6 +732,7 @@ def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, featur
 # but keep in mind, if you change the number of neurons / filters chances are a shape mismatch might occur.
 #
 # For more information, see:
+#
 # * `Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. In: Neural Information Processing System Deep Learning Workshop (2015) <https://arxiv.org/abs/1503.02531>`_
 #
 # * `Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: Fitnets: Hints for thin deep nets. In: Proceedings of the International Conference on Learning Representations (2015) <https://arxiv.org/abs/1412.6550>`_

0 commit comments
