|
1 |
| -from typing import Optional |
2 |
| - |
3 |
| -import libcst as cst |
4 |
| -import libcst.matchers as m |
5 |
| -from libcst.codemod.visitors import ImportItem |
6 |
| - |
7 |
| -from ...common import LintViolation, TorchVisitor |
8 |
| - |
9 |
| - |
10 |
| -class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): |
11 |
| - """ |
12 |
| - Find and fix deprecated `pretrained` parameters in TorchVision models. |
13 |
| -
|
14 |
| - Both `pretrained` and `pretrained_backbone` parameters are supported. |
15 |
| - The parameters are updated to the new `weights` and `weights_backbone` parameters |
16 |
| - only if the old parameter has explicit literal `True` or `False` value, |
17 |
| - otherwise only lint violation is emitted. |
18 |
| - """ |
19 |
| - |
20 |
| - ERROR_CODE = "TOR201" |
21 |
| - |
22 |
| - # flake8: noqa: E105 |
23 |
| - # fmt: off |
24 |
| - MODEL_WEIGHTS = { |
25 |
| - ("mobilenet_v2", "pretrained"): "MobileNet_V2_Weights.IMAGENET1K_V1", |
26 |
| - ("mobilenet_v3_large", "pretrained"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
27 |
| - ("mobilenet_v3_small", "pretrained"): "MobileNet_V3_Small_Weights.IMAGENET1K_V1", |
28 |
| - ("densenet121", "pretrained"): "DenseNet121_Weights.IMAGENET1K_V1", |
29 |
| - ("densenet161", "pretrained"): "DenseNet161_Weights.IMAGENET1K_V1", |
30 |
| - ("densenet169", "pretrained"): "DenseNet169_Weights.IMAGENET1K_V1", |
31 |
| - ("densenet201", "pretrained"): "DenseNet201_Weights.IMAGENET1K_V1", |
32 |
| - ("detection.maskrcnn_resnet50_fpn", "pretrained"): "MaskRCNN_ResNet50_FPN_Weights.COCO_V1", |
33 |
| - ("detection.maskrcnn_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
34 |
| - ("detection.maskrcnn_resnet50_fpn_v2", "pretrained"): "MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1", |
35 |
| - ("detection.maskrcnn_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
36 |
| - ("detection.retinanet_resnet50_fpn", "pretrained"): "RetinaNet_ResNet50_FPN_Weights.COCO_V1", |
37 |
| - ("detection.retinanet_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
38 |
| - ("detection.retinanet_resnet50_fpn_v2", "pretrained"): "RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1", |
39 |
| - ("detection.retinanet_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
40 |
| - ("optical_flow.raft_large", "pretrained"): "Raft_Large_Weights.C_T_SKHT_V2", |
41 |
| - ("optical_flow.raft_small", "pretrained"): "Raft_Small_Weights.C_T_V2", |
42 |
| - ("alexnet", "pretrained"): "AlexNet_Weights.IMAGENET1K_V1", |
43 |
| - ("convnext_tiny", "pretrained"): "ConvNeXt_Tiny_Weights.IMAGENET1K_V1", |
44 |
| - ("convnext_small", "pretrained"): "ConvNeXt_Small_Weights.IMAGENET1K_V1", |
45 |
| - ("convnext_base", "pretrained"): "ConvNeXt_Base_Weights.IMAGENET1K_V1", |
46 |
| - ("convnext_large", "pretrained"): "ConvNeXt_Large_Weights.IMAGENET1K_V1", |
47 |
| - ("inception_v3", "pretrained"): "Inception_V3_Weights.IMAGENET1K_V1", |
48 |
| - ("maxvit_t", "pretrained"): "MaxVit_T_Weights.IMAGENET1K_V1", |
49 |
| - ("mnasnet0_5", "pretrained"): "MNASNet0_5_Weights.IMAGENET1K_V1", |
50 |
| - ("mnasnet0_75", "pretrained"): "MNASNet0_75_Weights.IMAGENET1K_V1", |
51 |
| - ("mnasnet1_0", "pretrained"): "MNASNet1_0_Weights.IMAGENET1K_V1", |
52 |
| - ("mnasnet1_3", "pretrained"): "MNASNet1_3_Weights.IMAGENET1K_V1", |
53 |
| - ("detection.fasterrcnn_resnet50_fpn", "pretrained"): "FasterRCNN_ResNet50_FPN_Weights.COCO_V1", |
54 |
| - ("detection.fasterrcnn_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
55 |
| - ("detection.fasterrcnn_resnet50_fpn_v2", "pretrained"): "FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1", |
56 |
| - ("detection.fasterrcnn_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
57 |
| - ("detection.fasterrcnn_mobilenet_v3_large_320_fpn", "pretrained"): "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1", |
58 |
| - ("detection.fasterrcnn_mobilenet_v3_large_320_fpn", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
59 |
| - ("detection.fasterrcnn_mobilenet_v3_large_fpn", "pretrained"): "FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1", |
60 |
| - ("detection.fasterrcnn_mobilenet_v3_large_fpn", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
61 |
| - ("detection.fcos_resnet50_fpn", "pretrained"): "FCOS_ResNet50_FPN_Weights.COCO_V1", |
62 |
| - ("detection.fcos_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
63 |
| - ("segmentation.lraspp_mobilenet_v3_large", "pretrained"): "LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1", |
64 |
| - ("segmentation.lraspp_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
65 |
| - ("shufflenet_v2_x0_5", "pretrained"): "ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1", |
66 |
| - ("shufflenet_v2_x1_0", "pretrained"): "ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1", |
67 |
| - ("shufflenet_v2_x1_5", "pretrained"): "ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1", |
68 |
| - ("shufflenet_v2_x2_0", "pretrained"): "ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1", |
69 |
| - ("squeezenet1_0", "pretrained"): "SqueezeNet1_0_Weights.IMAGENET1K_V1", |
70 |
| - ("squeezenet1_1", "pretrained"): "SqueezeNet1_1_Weights.IMAGENET1K_V1", |
71 |
| - ("swin_t", "pretrained"): "Swin_T_Weights.IMAGENET1K_V1", |
72 |
| - ("swin_s", "pretrained"): "Swin_S_Weights.IMAGENET1K_V1", |
73 |
| - ("swin_b", "pretrained"): "Swin_B_Weights.IMAGENET1K_V1", |
74 |
| - ("swin_v2_t", "pretrained"): "Swin_V2_T_Weights.IMAGENET1K_V1", |
75 |
| - ("swin_v2_s", "pretrained"): "Swin_V2_S_Weights.IMAGENET1K_V1", |
76 |
| - ("swin_v2_b", "pretrained"): "Swin_V2_B_Weights.IMAGENET1K_V1", |
77 |
| - ("video.s3d", "pretrained"): "S3D_Weights.KINETICS400_V1", |
78 |
| - ("video.swin3d_t", "pretrained"): "Swin3D_T_Weights.KINETICS400_V1", |
79 |
| - ("video.swin3d_s", "pretrained"): "Swin3D_S_Weights.KINETICS400_V1", |
80 |
| - ("video.swin3d_b", "pretrained"): "Swin3D_B_Weights.KINETICS400_V1", |
81 |
| - ("vit_b_16", "pretrained"): "ViT_B_16_Weights.IMAGENET1K_V1", |
82 |
| - ("vit_b_32", "pretrained"): "ViT_B_32_Weights.IMAGENET1K_V1", |
83 |
| - ("vit_l_16", "pretrained"): "ViT_L_16_Weights.IMAGENET1K_V1", |
84 |
| - ("vit_l_32", "pretrained"): "ViT_L_32_Weights.IMAGENET1K_V1", |
85 |
| - ("vit_h_14", "pretrained"): "None", |
86 |
| - ("vgg11", "pretrained"): "VGG11_Weights.IMAGENET1K_V1", |
87 |
| - ("vgg11_bn", "pretrained"): "VGG11_BN_Weights.IMAGENET1K_V1", |
88 |
| - ("vgg13", "pretrained"): "VGG13_Weights.IMAGENET1K_V1", |
89 |
| - ("vgg13_bn", "pretrained"): "VGG13_BN_Weights.IMAGENET1K_V1", |
90 |
| - ("vgg16", "pretrained"): "VGG16_Weights.IMAGENET1K_V1", |
91 |
| - ("vgg16_bn", "pretrained"): "VGG16_BN_Weights.IMAGENET1K_V1", |
92 |
| - ("vgg19", "pretrained"): "VGG19_Weights.IMAGENET1K_V1", |
93 |
| - ("vgg19_bn", "pretrained"): "VGG19_BN_Weights.IMAGENET1K_V1", |
94 |
| - ("video.mvit_v1_b", "pretrained"): "MViT_V1_B_Weights.KINETICS400_V1", |
95 |
| - ("video.mvit_v2_s", "pretrained"): "MViT_V2_S_Weights.KINETICS400_V1", |
96 |
| - ("video.r3d_18", "pretrained"): "R3D_18_Weights.KINETICS400_V1", |
97 |
| - ("video.mc3_18", "pretrained"): "MC3_18_Weights.KINETICS400_V1", |
98 |
| - ("video.r2plus1d_18", "pretrained"): "R2Plus1D_18_Weights.KINETICS400_V1", |
99 |
| - ("regnet_y_400mf", "pretrained"): "RegNet_Y_400MF_Weights.IMAGENET1K_V1", |
100 |
| - ("regnet_y_800mf", "pretrained"): "RegNet_Y_800MF_Weights.IMAGENET1K_V1", |
101 |
| - ("regnet_y_1_6gf", "pretrained"): "RegNet_Y_1_6GF_Weights.IMAGENET1K_V1", |
102 |
| - ("regnet_y_3_2gf", "pretrained"): "RegNet_Y_3_2GF_Weights.IMAGENET1K_V1", |
103 |
| - ("regnet_y_8gf", "pretrained"): "RegNet_Y_8GF_Weights.IMAGENET1K_V1", |
104 |
| - ("regnet_y_16gf", "pretrained"): "RegNet_Y_16GF_Weights.IMAGENET1K_V1", |
105 |
| - ("regnet_y_32gf", "pretrained"): "RegNet_Y_32GF_Weights.IMAGENET1K_V1", |
106 |
| - ("regnet_y_128gf", "pretrained"): "None", |
107 |
| - ("regnet_x_400mf", "pretrained"): "RegNet_X_400MF_Weights.IMAGENET1K_V1", |
108 |
| - ("regnet_x_800mf", "pretrained"): "RegNet_X_800MF_Weights.IMAGENET1K_V1", |
109 |
| - ("regnet_x_1_6gf", "pretrained"): "RegNet_X_1_6GF_Weights.IMAGENET1K_V1", |
110 |
| - ("regnet_x_3_2gf", "pretrained"): "RegNet_X_3_2GF_Weights.IMAGENET1K_V1", |
111 |
| - ("regnet_x_8gf", "pretrained"): "RegNet_X_8GF_Weights.IMAGENET1K_V1", |
112 |
| - ("regnet_x_16gf", "pretrained"): "RegNet_X_16GF_Weights.IMAGENET1K_V1", |
113 |
| - ("regnet_x_32gf", "pretrained"): "RegNet_X_32GF_Weights.IMAGENET1K_V1", |
114 |
| - ("resnet18", "pretrained"): "ResNet18_Weights.IMAGENET1K_V1", |
115 |
| - ("resnet34", "pretrained"): "ResNet34_Weights.IMAGENET1K_V1", |
116 |
| - ("resnet50", "pretrained"): "ResNet50_Weights.IMAGENET1K_V1", |
117 |
| - ("resnet101", "pretrained"): "ResNet101_Weights.IMAGENET1K_V1", |
118 |
| - ("resnet152", "pretrained"): "ResNet152_Weights.IMAGENET1K_V1", |
119 |
| - ("resnext50_32x4d", "pretrained"): "ResNeXt50_32X4D_Weights.IMAGENET1K_V1", |
120 |
| - ("resnext101_32x8d", "pretrained"): "ResNeXt101_32X8D_Weights.IMAGENET1K_V1", |
121 |
| - ("resnext101_64x4d", "pretrained"): "ResNeXt101_64X4D_Weights.IMAGENET1K_V1", |
122 |
| - ("wide_resnet50_2", "pretrained"): "Wide_ResNet50_2_Weights.IMAGENET1K_V1", |
123 |
| - ("wide_resnet101_2", "pretrained"): "Wide_ResNet101_2_Weights.IMAGENET1K_V1", |
124 |
| - ("efficientnet_b0", "pretrained"): "EfficientNet_B0_Weights.IMAGENET1K_V1", |
125 |
| - ("efficientnet_b1", "pretrained"): "EfficientNet_B1_Weights.IMAGENET1K_V1", |
126 |
| - ("efficientnet_b2", "pretrained"): "EfficientNet_B2_Weights.IMAGENET1K_V1", |
127 |
| - ("efficientnet_b3", "pretrained"): "EfficientNet_B3_Weights.IMAGENET1K_V1", |
128 |
| - ("efficientnet_b4", "pretrained"): "EfficientNet_B4_Weights.IMAGENET1K_V1", |
129 |
| - ("efficientnet_b5", "pretrained"): "EfficientNet_B5_Weights.IMAGENET1K_V1", |
130 |
| - ("efficientnet_b6", "pretrained"): "EfficientNet_B6_Weights.IMAGENET1K_V1", |
131 |
| - ("efficientnet_b7", "pretrained"): "EfficientNet_B7_Weights.IMAGENET1K_V1", |
132 |
| - ("efficientnet_v2_s", "pretrained"): "EfficientNet_V2_S_Weights.IMAGENET1K_V1", |
133 |
| - ("efficientnet_v2_m", "pretrained"): "EfficientNet_V2_M_Weights.IMAGENET1K_V1", |
134 |
| - ("efficientnet_v2_l", "pretrained"): "EfficientNet_V2_L_Weights.IMAGENET1K_V1", |
135 |
| - ("googlenet", "pretrained"): "GoogLeNet_Weights.IMAGENET1K_V1", |
136 |
| - ("segmentation.deeplabv3_resnet50", "pretrained"): "DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1", |
137 |
| - ("segmentation.deeplabv3_resnet50", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
138 |
| - ("segmentation.deeplabv3_resnet101", "pretrained"): "DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1", |
139 |
| - ("segmentation.deeplabv3_resnet101", "pretrained_backbone"): "ResNet101_Weights.IMAGENET1K_V1", |
140 |
| - ("segmentation.deeplabv3_mobilenet_v3_large", "pretrained"): "DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1", |
141 |
| - ("segmentation.deeplabv3_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
142 |
| - ("segmentation.fcn_resnet50", "pretrained"): "FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1", |
143 |
| - ("segmentation.fcn_resnet50", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1", |
144 |
| - ("segmentation.fcn_resnet101", "pretrained"): "FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1", |
145 |
| - ("segmentation.fcn_resnet101", "pretrained_backbone"): "ResNet101_Weights.IMAGENET1K_V1", |
146 |
| - ("detection.ssd300_vgg16", "pretrained"): "SSD300_VGG16_Weights.COCO_V1", |
147 |
| - ("detection.ssd300_vgg16", "pretrained_backbone"): "VGG16_Weights.IMAGENET1K_FEATURES", |
148 |
| - ("detection.ssdlite320_mobilenet_v3_large", "pretrained"): "SSDLite320_MobileNet_V3_Large_Weights.COCO_V1", |
149 |
| - ("detection.ssdlite320_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1", |
150 |
| - } |
151 |
| - # fmt: on |
152 |
| - |
153 |
| - # The same model can be imported from torchvision.models directly, |
154 |
| - # or from a submodule like torchvision.models.resnet. |
155 |
| - MODEL_SUBMODULES = ( |
156 |
| - "alexnet", |
157 |
| - "convnext", |
158 |
| - "densenet", |
159 |
| - "efficientnet", |
160 |
| - "googlenet", |
161 |
| - "inception", |
162 |
| - "mnasnet", |
163 |
| - "mobilenet", |
164 |
| - "regnet", |
165 |
| - "resnet", |
166 |
| - "shufflenetv2", |
167 |
| - "squeezenet", |
168 |
| - "vgg", |
169 |
| - "vision_transformer", |
170 |
| - "swin_transformer", |
171 |
| - "maxvit", |
172 |
| - ) |
173 |
| - |
174 |
| - def visit_Call(self, node): |
175 |
| - def _new_arg_and_import( |
176 |
| - old_arg: cst.Arg, is_backbone: bool |
177 |
| - ) -> Optional[cst.Arg]: |
178 |
| - old_arg_name = "pretrained_backbone" if is_backbone else "pretrained" |
179 |
| - if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS: |
180 |
| - return None |
181 |
| - new_arg_name = "weights_backbone" if is_backbone else "weights" |
182 |
| - weights_arg = None |
183 |
| - if cst.ensure_type(old_arg.value, cst.Name).value == "True": |
184 |
| - weights_str = self.MODEL_WEIGHTS[(model_name, old_arg_name)] |
185 |
| - if is_backbone is False and len(model_name.split(".")) > 1: |
186 |
| - # Prepend things like 'detection.' to the weights string |
187 |
| - weights_str = model_name.split(".")[0] + "." + weights_str |
188 |
| - weights_str = "models." + weights_str |
189 |
| - weights_arg = cst.ensure_type( |
190 |
| - cst.parse_expression(f"f({new_arg_name}={weights_str})"), cst.Call |
191 |
| - ).args[0] |
192 |
| - self.needed_imports.add( |
193 |
| - ImportItem( |
194 |
| - module_name="torchvision", |
195 |
| - obj_name="models", |
196 |
| - ) |
197 |
| - ) |
198 |
| - elif cst.ensure_type(old_arg.value, cst.Name).value == "False": |
199 |
| - weights_arg = cst.ensure_type( |
200 |
| - cst.parse_expression(f"f({new_arg_name}=None)"), cst.Call |
201 |
| - ).args[0] |
202 |
| - return weights_arg |
203 |
| - |
204 |
| - qualified_name = self.get_qualified_name_for_call(node) |
205 |
| - if qualified_name is None: |
206 |
| - return |
207 |
| - if qualified_name.startswith("torchvision.models"): |
208 |
| - model_name = qualified_name[len("torchvision.models") + 1 :] |
209 |
| - for submodule in self.MODEL_SUBMODULES: |
210 |
| - if model_name.startswith(submodule + "."): |
211 |
| - model_name = model_name[len(submodule) + 1 :] |
212 |
| - |
213 |
| - if (model_name, "pretrained") not in self.MODEL_WEIGHTS: |
214 |
| - return |
215 |
| - |
216 |
| - message = None |
217 |
| - pretrained_arg = self.get_specific_arg(node, "pretrained", 0) |
218 |
| - if pretrained_arg is not None: |
219 |
| - message = "Parameter `pretrained` is deprecated, please use `weights` instead." |
220 |
| - |
221 |
| - pretrained_backbone_arg = self.get_specific_arg( |
222 |
| - node, "pretrained_backbone", 1 |
223 |
| - ) |
224 |
| - if pretrained_backbone_arg is not None: |
225 |
| - message = "Parameter `pretrained_backbone` is deprecated, please use `weights_backbone` instead." |
226 |
| - |
227 |
| - replacement_args = list(node.args) |
228 |
| - |
229 |
| - new_pretrained_arg = _new_arg_and_import(pretrained_arg, is_backbone=False) |
230 |
| - has_replacement = False |
231 |
| - if new_pretrained_arg is not None: |
232 |
| - for pos, arg in enumerate(node.args): |
233 |
| - if arg is pretrained_arg: |
234 |
| - break |
235 |
| - replacement_args[pos] = new_pretrained_arg |
236 |
| - has_replacement = True |
237 |
| - |
238 |
| - new_pretrained_backbone_arg = _new_arg_and_import( |
239 |
| - pretrained_backbone_arg, is_backbone=True |
240 |
| - ) |
241 |
| - if new_pretrained_backbone_arg is not None: |
242 |
| - for pos, arg in enumerate(node.args): |
243 |
| - if arg is pretrained_backbone_arg: |
244 |
| - break |
245 |
| - replacement_args[pos] = new_pretrained_backbone_arg |
246 |
| - has_replacement = True |
247 |
| - |
248 |
| - replacement = ( |
249 |
| - node.with_changes(args=replacement_args) if has_replacement else None |
250 |
| - ) |
251 |
| - if message is not None: |
252 |
| - position_metadata = self.get_metadata( |
253 |
| - cst.metadata.WhitespaceInclusivePositionProvider, node |
254 |
| - ) |
255 |
| - self.violations.append( |
256 |
| - LintViolation( |
257 |
| - error_code=self.ERROR_CODE, |
258 |
| - message=message, |
259 |
| - line=position_metadata.start.line, |
260 |
| - column=position_metadata.start.column, |
261 |
| - node=node, |
262 |
| - replacement=replacement, |
263 |
| - ) |
264 |
| - ) |
| 1 | +from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 |
| 2 | +from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 |
0 commit comments