From 7434662a30de1aa6df16657749544577908182ed Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 25 Jan 2022 09:32:23 +0000 Subject: [PATCH] Updating code on the article. --- ...orchvision-new-multi-weight-support-api.md | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/_posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md b/_posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md index 6ca7cb4767c8..d77d518f4765 100644 --- a/_posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md +++ b/_posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md @@ -74,7 +74,7 @@ from torchvision.prototype import models as PM img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg") # Step 1: Initialize model -weights = PM.ResNet50_Weights.ImageNet1K_V1 +weights = PM.ResNet50_Weights.IMAGENET1K_V1 model = PM.resnet50(weights=weights) model.eval() @@ -96,7 +96,7 @@ As we can see the new API eliminates the aforementioned limitations. Let’s exp ### Multi-weight support -At the heart of the new API, we have the ability to define multiple different weights for the same model variant. Each model building method (eg `resnet50`) has an associated Enum class (eg `ResNet50_Weights`) which has as many entries as the number of pre-trained weights available. Additionally, each Enum class has a `default` alias which points to the best available weights for the specific model. This allows the users who want to always use the best available weights to do so without modifying their code. +At the heart of the new API, we have the ability to define multiple different weights for the same model variant. Each model building method (eg `resnet50`) has an associated Enum class (eg `ResNet50_Weights`) which has as many entries as the number of pre-trained weights available. Additionally, each Enum class has a `DEFAULT` alias which points to the best available weights for the specific model. This allows the users who want to always use the best available weights to do so without modifying their code. Here is an example of initializing models with different weights: @@ -104,13 +104,13 @@ Here is an example of initializing models with different weights: from torchvision.prototype.models import resnet50, ResNet50_Weights # Legacy weights with accuracy 76.130% -model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1) +model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) -# New weights with accuracy 80.674% -model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2) +# New weights with accuracy 80.858% +model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) -# Best available weights (currently alias for ImageNet1K_V2) -model = resnet50(weights=ResNet50_Weights.default) +# Best available weights (currently alias for IMAGENET1K_V2) +model = resnet50(weights=ResNet50_Weights.DEFAULT) # No weights - random initialization model = resnet50(weights=None) @@ -124,10 +124,10 @@ The weights of each model are associated with meta-data. The type of information from torchvision.prototype.models import ResNet50_Weights # Accessing a single record -size = ResNet50_Weights.ImageNet1K_V2.meta["size"] +size = ResNet50_Weights.IMAGENET1K_V2.meta["size"] # Iterating the items of the meta-data dictionary -for k, v in ResNet50_Weights.ImageNet1K_V2.meta.items(): +for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items(): print(k, v) ``` @@ -137,10 +137,10 @@ Additionally, each weights entry is associated with the necessary preprocessing from torchvision.prototype.models import ResNet50_Weights # Initializing preprocessing at standard 224x224 resolution -preprocess = ResNet50_Weights.ImageNet1K.transforms() +preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms() # Initializing preprocessing at 400x400 resolution -preprocess = ResNet50_Weights.ImageNet1K.transforms(crop_size=400, resize_size=400) +preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400) # Once initialized the callable can accept the image data: # img_preprocessed = preprocess(img) @@ -156,11 +156,11 @@ The ability to link directly the weights with their properties (meta data, prepr from torchvision.prototype.models import get_weight # Weights can be retrieved by name: -assert get_weight("ResNet50_Weights.ImageNet1K_V1") == ResNet50_Weights.ImageNet1K_V1 -assert get_weight("ResNet50_Weights.ImageNet1K_V2") == ResNet50_Weights.ImageNet1K_V2 +assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1 +assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2 -# Including using the default alias: -assert get_weight("ResNet50_Weights.default") == ResNet50_Weights.ImageNet1K_V2 +# Including using the DEFAULT alias: +assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2 ``` ## Deprecations @@ -172,8 +172,8 @@ In the new API the boolean `pretrained` and `pretrained_backbone` parameters, wh UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead. UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated. -The current behavior is equivalent to passing `weights=ResNet50_Weights.ImageNet1K_V1`. -You can also use `weights=ResNet50_Weights.default` to get the most up-to-date weights. +The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. +You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights. ``` Additionally the builder methods require using keyword parameters. The use of positional parameter is deprecated and using them emits the following warning: @@ -191,7 +191,7 @@ Migrating to the new API is very straightforward. The following method calls bet ``` # Using pretrained weights: -torchvision.prototype.models.resnet50(weights=ResNet50_Weights.ImageNet1K_V1) +torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) torchvision.models.resnet50(pretrained=True) torchvision.models.resnet50(True) @@ -237,7 +237,7 @@ If you are still unconvinced about giving a try to the new API, here is one more |RegNet Y 8gf |80.032 |82.828 | |RegNet Y 16gf |80.424 |82.89 | |RegNet Y 32gf |80.878 |83.366 | -|ResNet50 |76.13 |80.674 | +|ResNet50 |76.13 |80.858 | |ResNet101 |77.374 |81.886 | |ResNet152 |78.312 |82.284 | |ResNeXt50 32x4d |77.618 |81.198 |