"""
Semi-Supervised Learning using USB built upon PyTorch
=====================================================

**Author**: `Hao Chen <https://github.com/Hhhhhhao>`_

Unified Semi-supervised learning Benchmark (USB) is a semi-supervised
learning framework built upon PyTorch.
Based on Datasets and Modules provided by PyTorch, USB becomes a flexible,
modular, and easy-to-use framework for semi-supervised learning.
It supports a variety of semi-supervised learning algorithms, including
``FixMatch``, ``FreeMatch``, ``DeFixMatch``, ``SoftMatch``, and so on.
It also supports a variety of imbalanced semi-supervised learning algorithms.
The benchmark results across different datasets of computer vision, natural
language processing, and speech processing are included in USB.

This tutorial will walk you through the basics of using the USB lightning
package.
Let's get started by training a ``FreeMatch``/``SoftMatch`` model on
CIFAR-10 using a pretrained ViT!
We will also show that it is easy to change the semi-supervised algorithm
and train on imbalanced datasets.


.. figure:: /_static/img/usb_semisup_learn/code.png
    :alt: USB framework illustration
"""


######################################################################
# Introduction to ``FreeMatch`` and ``SoftMatch`` in Semi-Supervised Learning
# ---------------------------------------------------------------------------
#
# Here we provide a brief introduction to ``FreeMatch`` and ``SoftMatch``.
# First, we introduce a famous baseline for semi-supervised learning called
# ``FixMatch``.
# ``FixMatch`` is a very simple framework for semi-supervised learning: the
# model's predictions on weakly augmented unlabeled data serve as pseudo
# labels for the strongly augmented views of the same data.
# It adopts a confidence thresholding strategy to filter out low-confidence
# pseudo labels with a fixed threshold.
# ``FreeMatch`` and ``SoftMatch`` are two algorithms that improve upon
# ``FixMatch``.
# ``FreeMatch`` proposes an adaptive thresholding strategy to replace the
# fixed thresholding strategy in ``FixMatch``. The adaptive thresholding
# progressively adjusts the threshold according to the learning status of
# the model on each class. ``SoftMatch`` absorbs the idea of confidence
# thresholding as a weighting mechanism. It proposes a Gaussian weighting
# mechanism to overcome the quantity-quality trade-off in pseudo labels.
# In this tutorial, we will use USB to train ``FreeMatch`` and ``SoftMatch``.
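######################################################################
# To make these ideas concrete, here is a minimal, library-free sketch of
# the three strategies: fixed thresholding, adaptive thresholding, and
# Gaussian weighting. This is an illustration only, not USB's actual
# implementation; the function names and the Gaussian parameters below are
# our own assumptions.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def fixed_threshold(logits, tau=0.95):
    # FixMatch-style: keep a pseudo label only if its confidence
    # exceeds a fixed threshold tau.
    probs = softmax(logits)
    conf = max(probs)
    pseudo_label = probs.index(conf)
    return pseudo_label, conf >= tau

def adaptive_threshold(tau, mean_conf, momentum=0.999):
    # FreeMatch-style (global part): the threshold tracks an exponential
    # moving average of the model's mean confidence on unlabeled data.
    return momentum * tau + (1 - momentum) * mean_conf

def gaussian_weight(conf, mu=0.9, var=0.01):
    # SoftMatch-style: full weight above the average confidence mu, and a
    # smooth Gaussian decay below it, instead of a hard keep/drop decision.
    gap = min(conf - mu, 0.0)
    return math.exp(-gap * gap / (2 * var))
```

# Confident predictions pass the fixed threshold and receive a weight close
# to 1, while uncertain ones are dropped (FixMatch/FreeMatch) or smoothly
# down-weighted (SoftMatch).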


######################################################################
# Use USB to Train ``FreeMatch``/``SoftMatch`` on CIFAR-10 with only 40 labels
# ----------------------------------------------------------------------------
#
# USB is easy to use and extend, affordable to small groups, and comprehensive
# for developing and evaluating SSL algorithms.
# USB provides the implementation of 14 SSL algorithms based on Consistency
# Regularization, and 15 tasks for evaluation from the CV, NLP, and Audio
# domains.
# It has a modular design that allows users to easily extend the package by
# adding new algorithms and tasks.
# It also supports a Python API for easier adaptation to different SSL
# algorithms on new data.
#
#
# Now, let's use USB to train ``FreeMatch`` and ``SoftMatch`` on CIFAR-10.
# First, we need to install the USB package ``semilearn`` and import the
# necessary API functions from USB.
# Below is a list of functions we will use from ``semilearn``:
#
# - ``get_dataset`` to load the dataset; here we use CIFAR-10
# - ``get_data_loader`` to create train (labeled and unlabeled) and test data
#   loaders; the train unlabeled loaders will provide both strong and weak
#   augmentations of unlabeled data
# - ``get_net_builder`` to create a model; here we use a pretrained ViT
# - ``get_algorithm`` to create the semi-supervised learning algorithm;
#   here we use ``FreeMatch`` and ``SoftMatch``
# - ``get_config`` to get the default configuration of the algorithm
# - ``Trainer``, a Trainer class for training and evaluating the
#   algorithm on the dataset
#
import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer

######################################################################
# After importing the necessary functions, we first set the hyper-parameters
# of the algorithm.
#
config = {
    'algorithm': 'freematch',


######################################################################
# We can start training the algorithms on CIFAR-10 with 40 labels now.
# We train for 4000 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)


######################################################################
# Finally, let's evaluate the trained model on the validation set.
# After training 4000 iterations with ``FreeMatch`` on only 40 labels of
# CIFAR-10, we obtain a classifier that achieves above 93% accuracy on the
# validation set.
trainer.evaluate(eval_loader)



######################################################################
# Use USB to Train ``SoftMatch`` with a specific imbalanced algorithm on imbalanced CIFAR-10
# ------------------------------------------------------------------------------------------
#
# Now let's say we have an imbalanced labeled set and unlabeled set of
# CIFAR-10, and we want to train a ``SoftMatch`` model on it.
# We create an imbalanced labeled set and an imbalanced unlabeled set of
# CIFAR-10 by setting the ``lb_imb_ratio`` and ``ulb_imb_ratio`` to 10.
# Also, we replace ``algorithm`` with ``softmatch`` and set ``imbalanced``
# to ``True``.
#
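######################################################################
# As an aside, a common convention for long-tailed benchmarks (an assumption
# here, not necessarily USB's exact recipe) is that an imbalance ratio of 10
# means per-class sample counts decay exponentially from the head class to
# the tail class by an overall factor of 10. A quick sketch of that profile:

```python
def long_tailed_counts(n_head, imb_ratio, num_classes):
    # Exponentially decaying per-class sample counts: the head class keeps
    # n_head samples and the tail class keeps n_head / imb_ratio samples.
    return [
        round(n_head * imb_ratio ** (-k / (num_classes - 1)))
        for k in range(num_classes)
    ]
```

# For example, under this convention, with 1500 head-class labels and ratio
# 10 on CIFAR-10, the tail class would keep 150 labels.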
config = {
    'algorithm': 'softmatch',


######################################################################
# Finally, let's evaluate the trained model on the validation set.
#
trainer.evaluate(eval_loader)
