Update the ResNet50 article #922

Merged 2 commits on Feb 1, 2022

A few weeks ago, TorchVision v0.11 was released packed with numerous new primitives, models and training recipe improvements which allowed achieving state-of-the-art (SOTA) results. The project was dubbed “[TorchVision with Batteries Included](https://github.com/pytorch/vision/issues/3911)” and aimed to modernize our library. We wanted to enable researchers to reproduce papers and conduct research more easily by using common building blocks. Moreover, we aspired to provide the necessary tools to Applied ML practitioners to train their models on their own data using the same SOTA techniques as in research. Finally, we wanted to refresh our pre-trained weights and offer better off-the-shelf models to our users, hoping that they would build better applications.

Though there is still much work to be done, we wanted to share with you some exciting results from the above work. We will showcase how one can use the new tools included in TorchVision to achieve state-of-the-art results on a highly competitive and well-studied architecture such as ResNet50 [[1]](https://arxiv.org/abs/1512.03385). We will share the exact recipe used to improve our baseline by over 4.7 accuracy points to reach a final top-1 accuracy of 80.9% and share the journey for deriving the new training process. Moreover, we will show that this recipe generalizes well to other model variants and families. We hope that the above will influence future research for developing stronger generalizable training methodologies and will inspire the community to adopt and contribute to our efforts.

## The Results

Using our new training recipe found on ResNet50, we've refreshed the pre-trained weights of the following models:

| Model | Accuracy@1 | Accuracy@5 |
|----------|:--------:|:----------:|
| ResNet50 | 80.858 | 95.434 |
| ResNet101 | 81.886 | 95.780 |
| ResNet152 | 82.284 | 96.002 |
| ResNeXt50-32x4d | 81.198 | 95.340 |

Note that the accuracy of all models except ResNet50 can be further improved by adjusting their training parameters slightly, but our focus was to have a single robust recipe which performs well for all.

**UPDATE:** We have refreshed the majority of popular classification models of TorchVision; you can find the details in this [blog post](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/).

There are currently two ways to use the latest weights of the model.

We are currently working on a new prototype mechanism which will extend the model builder methods of TorchVision to support multiple weights. Along with the weights, the new mechanism exposes the preprocessing transforms necessary for inference:

```
from PIL import Image
from torchvision import prototype as P

img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Initialize model
weights = P.models.ResNet50_Weights.IMAGENET1K_V2
model = P.models.resnet50(weights=weights)
model.eval()

# Initialize inference transforms
preprocess = weights.transforms()
```

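A minimal sketch of the inference step that would follow, assuming the weights object exposes the ImageNet class names via `weights.meta["categories"]` (an assumption, as the metadata is not shown above):

```
# Apply the inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Run the model and report the top prediction
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")
```
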
Those who don't want to use a prototype API have the option of accessing the new weights via the legacy API as follows:

```
from torchvision.models import resnet

# Overwrite the URL of the previous weights
resnet.model_urls["resnet50"] = "https://download.pytorch.org/models/resnet50-11ad3fa6.pth"

# Initialize the model using the legacy API
model = resnet.resnet50(pretrained=True)
```

## The Training Recipe

Our goal was to use the newly introduced primitives of TorchVision to derive a new strong training recipe which achieves state-of-the-art results for the vanilla ResNet50 architecture when trained from scratch on ImageNet with no additional external data. Though by using architecture specific tricks [[2]](https://arxiv.org/abs/1812.01187) one could further improve the accuracy, we’ve decided not to include them so that the recipe can be used in other architectures. Our recipe heavily focuses on simplicity and builds upon work by FAIR [[3]](https://arxiv.org/abs/2103.06877), [[4]](https://arxiv.org/abs/2106.14881), [[5]](https://arxiv.org/abs/1906.06423), [[6]](https://arxiv.org/abs/2012.12877), [[7]](https://arxiv.org/abs/2110.00476). Our findings align with the parallel study of Wightman et al. [[7]](https://arxiv.org/abs/2110.00476), who also report major accuracy improvements by focusing on the training recipes.

Without further ado, here are the main parameters of our recipe:

```
cutmix_alpha=1.0,
auto_augment='ta_wide',
random_erase=0.1,
ra_sampler=True,
ra_reps=4,

# EMA configuration
```

Using TorchVision's reference training script, the full recipe can be launched with:

```
torchrun --nproc_per_node=8 train.py --model resnet50 --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
```

## Methodology
In the table below, we provide a summary of the performance of stacked incremental improvements on top of the baseline:

| Model | Accuracy@1 | Accuracy@5 | Incremental Diff Acc@1 | Absolute Diff Acc@1 |
|----------|:--------:|:----------:|:---------:|:--------:|
| + EMA | 80.450 | 94.908 | 0.254 | 4.320 |
| + Inference Resize tuning * | 80.674 | 95.166 | 0.224 | 4.544 |
| + Repeated Augmentation ** | 80.858 | 95.434 | 0.184 | 4.728 |

*The tuning of the inference size was done on top of the last model. See below for details.

** Community contribution done after the release of the article. See below for details.

## Baseline

Our baseline is the previously released ResNet50 model of TorchVision. It was trained with the following recipe:
This further increases our top-1 Accuracy by 1.8 points on top of the previous step.

## Random Erasing

Another data augmentation technique known to help the classification accuracy is Random Erasing [[10]](https://arxiv.org/abs/1708.04896), [[11]](https://arxiv.org/abs/1708.04552). Often paired with Automatic Augmentation methods, it usually yields additional improvements in accuracy due to its regularization effect. In our experiments we tuned only the probability of applying the method via a grid search and found that it’s beneficial to keep its probability at low levels, typically around 10%. 

Here is the extra parameter introduced on top of the previous:
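
The value matches the full recipe and the training command shown earlier:

```
random_erase=0.1,
```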

## Label Smoothing

We use PyTorch's newly introduced [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) `label_smoothing` parameter.
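
As a quick illustration, here is a minimal sketch of how such a loss can be constructed with the smoothing value from the recipe above (this is not a snippet from the original article):

```
import torch.nn as nn

# label_smoothing value taken from the recipe above
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```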

## Mixup and Cutmix

Two data augmentation techniques often used to produce SOTA results are Mixup and Cutmix [[13]](https://arxiv.org/abs/1710.09412), [[14]](https://arxiv.org/abs/1905.04899). They both provide strong regularization effects by softening not only the labels but also the images. In our setup we found it beneficial to apply one of them randomly with equal probability. Each is parameterized with a hyperparameter alpha, which controls the shape of the Beta distribution from which the smoothing probability is sampled. We did a very limited grid search, focusing primarily on common values proposed in the papers.

Below you will find the optimal values for the alpha parameters of the two techniques:
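
The values match the full recipe and the training command shown earlier:

```
mixup_alpha=0.2,
cutmix_alpha=1.0,
```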

The above update improves our accuracy by a further 0.526 points, providing additional gains on top of the previous step.

## FixRes mitigations

An important property identified early in our experiments is the fact that the models performed significantly better if the resolution used during validation was increased from the 224x224 used in training. This effect is studied in detail in the FixRes paper [[5]](https://arxiv.org/abs/1906.06423) and two mitigations are proposed: a) one could try to reduce the training resolution so that the accuracy at the validation resolution is maximized or b) one could fine-tune the model with a two-phase training so that it adjusts to the target resolution. Since we didn't want to introduce a 2-phase training, we went for option a). This means that we reduced the train crop size from 224 and used a grid search to find the one that maximizes validation accuracy at a resolution of 224x224.

Below you can see the optimal value used on our recipe:
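
The value matches the `--train-crop-size 176` flag of the training command shown earlier:

```
train_crop_size=176,
```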

## Inference Resize tuning

Unlike all other steps of the process which involved training models with different parameters, this optimization was done on top of the last model.
Below you can see the optimal value used on our recipe:

```
val_resize_size=232,
```

The above is an optimization which improved our accuracy by 0.224 points. It's worth noting that the optimal value for ResNet50 also works best for ResNet101, ResNet152 and ResNeXt50, which hints that it generalizes across models:


<div style="display: flex">
<img src="/assets/images/sota/ResNet152 Inference Resize.png" alt="Best ResNet50 trained with 224 Resolution" width="30%">
</div>
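
To make the tuned value concrete, here is a minimal sketch of the validation preprocessing it implies, assuming the standard 224x224 evaluation crop and ImageNet normalization (the exact transform of the reference script is not shown here):

```
import torchvision.transforms as T

# Resize the shorter side to the tuned value, then center-crop to the
# 224x224 evaluation resolution; standard ImageNet normalization assumed.
val_transform = T.Compose([
    T.Resize(232),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```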

## [UPDATE] Repeated Augmentation

Repeated Augmentation [[15]](https://arxiv.org/abs/1901.09335), [[16]](https://arxiv.org/abs/1902.05509) is another technique which can improve the overall accuracy and has been used by other strong recipes such as those in [[6]](https://arxiv.org/abs/2012.12877), [[7]](https://arxiv.org/abs/2110.00476). Tal Ben-Nun, a community contributor, has [further improved](https://github.com/pytorch/vision/pull/5201) our original recipe by proposing that the model be trained with 4 repetitions. His contribution came after the release of this article.

Below you can see the optimal value used on our recipe:

```
ra_sampler=True,
ra_reps=4,
```

The above is the final optimization which improved our accuracy by 0.184 points. 
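
For readers unfamiliar with the technique, here is a simplified conceptual sketch of repeated augmentation (not the sampler used in the reference script): each selected image appears several times per epoch, and every occurrence goes through a different random augmentation.

```
import torch

def repeated_augmentation_indices(dataset_len, ra_reps=4, seed=0):
    # Shuffle the dataset once, then repeat each index ra_reps times so the
    # same image is drawn multiple times and augmented differently each time.
    g = torch.Generator().manual_seed(seed)
    order = torch.randperm(dataset_len, generator=g)
    return order.repeat_interleave(ra_reps).tolist()
```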

## Optimizations that were tested but not adopted

During the early stages of our research, we experimented with additional techniques, configurations and optimizations. Since our target was to keep our recipe as simple as possible, we decided not to include anything that didn't provide a significant improvement. Here are a few approaches that we took but that didn't make it into our final recipe:

## Acknowledgements

We would like to thank Piotr Dollar, Mannat Singh and Hugo Touvron for providing their insights and feedback during the development of the recipe and for their previous research work on which our recipe is based. Their support was invaluable for achieving the above result. Moreover, we would like to thank Prabhat Roy, Kai Zhang, Yiwen Song, Joel Schlosser, Ilqar Ramazanli, Francisco Massa, Mannat Singh, Xiaoliang Dai, Samuel Gabriel, Allen Goodman and Tal Ben-Nun for their contributions to the Batteries Included project.

## References

12. Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, Zbigniew Wojna. “Rethinking the Inception Architecture for Computer Vision”
13. Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. “mixup: Beyond Empirical Risk Minimization”
14. Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo. “CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features”
15. Elad Hoffer, Tal Ben-Nun, Itay Hubara, Niv Giladi, Torsten Hoefler, Daniel Soudry. “Augment your batch: better training with larger batches”
16. Maxim Berman, Hervé Jégou, Andrea Vedaldi, Iasonas Kokkinos, Matthijs Douze. “Multigrain: a unified image embedding for classes and instances”