5 | 5 | **Author**: `Hao Chen <https://github.com/Hhhhhhao>`_
6 | 6 |
7 | 7 | Unified Semi-supervised learning Benchmark (USB) is a semi-supervised
8 |   | -learning framework built upon PyTorch.
  | 8 | +learning (SSL) framework built upon PyTorch.
9 | 9 | Based on Datasets and Modules provided by PyTorch, USB becomes a flexible,
10 | 10 | modular, and easy-to-use framework for semi-supervised learning.
11 | 11 | It supports a variety of semi-supervised learning algorithms, including
17 | 17 | This tutorial will walk you through the basics of using the USB lightning
18 | 18 | package.
19 | 19 | Let's get started by training a ``FreeMatch``/``SoftMatch`` model on
20 |   | -CIFAR-10 using pretrained ViT!
  | 20 | +CIFAR-10 using pretrained Vision Transformers (ViT)!
21 | 21 | And we will show it is easy to change the semi-supervised algorithm and train
22 | 22 | on imbalanced datasets.
23 | 23 |
64 | 64 | # Now, let's use USB to train ``FreeMatch`` and ``SoftMatch`` on CIFAR-10.
65 | 65 | # First, we need to install USB package ``semilearn`` and import necessary API
66 | 66 | # functions from USB.
   | 67 | +# If you are running this in Google Colab, install ``semilearn`` by running:
   | 68 | +# ``!pip install semilearn``.
   | 69 | +#
67 | 70 | # Below is a list of functions we will use from ``semilearn``:
68 | 71 | #
69 | 72 | # - ``get_dataset`` to load dataset, here we use CIFAR-10
77 | 80 | # - ``Trainer``: a Trainer class for training and evaluating the
78 | 81 | #   algorithm on dataset
79 | 82 | #
   | 83 | +# Note that a CUDA-enabled backend is required for training with the ``semilearn`` package.
   | 84 | +# See `Enabling CUDA in Google Colab <https://pytorch.org/tutorials/beginner/colab#using-cuda>`__ for instructions
   | 85 | +# on enabling CUDA in Google Colab.
   | 86 | +#
80 | 87 | import semilearn
81 | 88 | from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer
82 | 89 |
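Since the hunk above adds a hard CUDA requirement, it is worth verifying the backend before training starts. A minimal sketch (not part of the diff), using only standard ``torch`` calls:

    import torch

    # Fail fast if no CUDA device is visible, rather than partway through training.
    assert torch.cuda.is_available(), "This tutorial requires a CUDA-enabled GPU."
    print("Using device:", torch.cuda.get_device_name(0))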
92 | 99 |
93 | 100 |     # optimization configs
94 | 101 |     'epoch': 1,
95 |     | -    'num_train_iter': 4000,
    | 102 | +    'num_train_iter': 500,
96 | 103 |     'num_eval_iter': 500,
97 | 104 |     'num_log_iter': 50,
98 | 105 |     'optim': 'AdamW',
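After this change ``num_train_iter`` equals ``num_eval_iter`` (both 500), so evaluation effectively runs once, at the end of training, while ``num_log_iter: 50`` still logs every 50 steps. These keys live in a plain dict that the tutorial normalizes with ``get_config``; a minimal sketch, assuming the dict-to-namespace pattern shown in USB's examples (only the keys from this hunk appear; the full config also carries algorithm, dataset, and model keys):

    # Sketch: optimization keys from this hunk only; the real config has more.
    config = get_config({
        'epoch': 1,
        'num_train_iter': 500,   # total optimization steps
        'num_eval_iter': 500,    # evaluate every 500 steps -> once, at the end
        'num_log_iter': 50,      # log training statistics every 50 steps
        'optim': 'AdamW',
    })
    print(config.num_train_iter)  # attribute access after normalization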
141 | 148 |
142 | 149 | ######################################################################
143 | 150 | # We can start training the algorithms on CIFAR-10 with 40 labels now.
144 |     | -# We train for 4000 iterations and evaluate every 500 iterations.
    | 151 | +# We train for 500 iterations and evaluate every 500 iterations.
145 | 152 | #
146 | 153 | trainer = Trainer(config, algorithm)
147 | 154 | trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
148 | 155 |
149 | 156 |
150 | 157 | ######################################################################
151 | 158 | # Finally, let's evaluate the trained model on the validation set.
152 |     | -# After training 4000 iterations with ``FreeMatch`` on only 40 labels of
153 |     | -# CIFAR-10, we obtain a classifier that achieves above 93 accuracy on the validation set.
    | 159 | +# After training 500 iterations with ``FreeMatch`` on only 40 labels of
    | 160 | +# CIFAR-10, we obtain a classifier that achieves around 87% accuracy on the validation set.
154 | 161 | trainer.evaluate(eval_loader)
155 | 162 |
156 | 163 |
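The hunk below repeats the optimization config for the second experiment, where the algorithm switches from ``FreeMatch`` to ``SoftMatch``. In USB that switch is a one-key change; a sketch, assuming a plain settings dict (``config_dict`` is a hypothetical name) and the constructors imported earlier, with the ``get_algorithm``/``get_net_builder`` call shapes taken as an assumption from USB's examples:

    # Assumed pattern: only the 'algorithm' key changes; the data loaders are
    # rebuilt exactly as in the FreeMatch section.
    config_dict['algorithm'] = 'softmatch'
    config = get_config(config_dict)
    algorithm = get_algorithm(config, get_net_builder(config.net, from_name=True),
                              tb_log=None, logger=None)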
174 | 181 |
175 | 182 |     # optimization configs
176 | 183 |     'epoch': 1,
177 |     | -    'num_train_iter': 4000,
    | 184 | +    'num_train_iter': 500,
178 | 185 |     'num_eval_iter': 500,
179 | 186 |     'num_log_iter': 50,
180 | 187 |     'optim': 'AdamW',
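The intro also promises training on imbalanced datasets; in USB this is driven by config keys rather than new code. The key names below are an assumption based on USB's documented config and should be verified against the installed ``semilearn`` version:

    # ASSUMED key names -- verify against semilearn's config documentation.
    config_dict.update({
        'lb_imb_ratio': 10,   # labeled-set long-tail ratio (largest / smallest class)
        'ulb_imb_ratio': 10,  # unlabeled-set long-tail ratio
    })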
225 | 232 |
226 | 233 | ######################################################################
227 | 234 | # We can start training the algorithms on CIFAR-10 with 40 labels now.
228 |     | -# We train for 4000 iterations and evaluate every 500 iterations.
    | 235 | +# We train for 500 iterations and evaluate every 500 iterations.
229 | 236 | #
230 | 237 | trainer = Trainer(config, algorithm)
231 | 238 | trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
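As in the ``FreeMatch`` section above, the trained ``SoftMatch`` model is then scored on the validation set with the same call:

    # Same evaluation entry point used after the FreeMatch run.
    trainer.evaluate(eval_loader)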
239 | 246 |
240 | 247 |
241 | 248 | ######################################################################
242 |     | -# References
243 |     | -# [1] USB: https://github.com/microsoft/Semi-supervised-learning
244 |     | -# [2] Kihyuk Sohn et al. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
245 |     | -# [3] Yidong Wang et al. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning
246 |     | -# [4] Hao Chen et al. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning
    | 249 | +# References:
    | 250 | +# - [1] USB: https://github.com/microsoft/Semi-supervised-learning
    | 251 | +# - [2] Kihyuk Sohn et al. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
    | 252 | +# - [3] Yidong Wang et al. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning
    | 253 | +# - [4] Hao Chen et al. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning