Skip to content

Commit 6959ea6

Browse files
authored
[TorchFix] add visitor for deprecated TorchVision transform (#4615)
This adds a violation when the [deprecated](https://github.com/pytorch/vision/blob/ddfee23d56700ba84fd28805d5cbdeac2f28f2a7/torchvision/transforms/v2/_deprecated.py#L12-L13) `torchvision.transforms.v2.ToTensor` transform is imported or used. For the future, there are two more things to do: 1. The corresponding functional is also [deprecated](https://github.com/pytorch/vision/blob/ddfee23d56700ba84fd28805d5cbdeac2f28f2a7/torchvision/transforms/v2/functional/_deprecated.py#L10-L11) and needs the same treatment as the function. If we keep the long violation message (which is the same message that one would see using this at runtime), we probably should have a different error code for the functional to avoid making the message even longer. Otherwise we may also have both under the same error code. 2. Instead of just complaining about the use of deprecated functionality, we could also provide a fix. However, this is not straightforward, since we need to replace `v2.ToTensor()` with `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`. While that should work everywhere, usually `v2.ToTensor` is not used as standalone, but rather already as part of a `Compose`, e.g. ```python transform = transforms.Compose([ transforms.Resize(...), transforms.CenterCrop(...), transforms.ToTensor(), transforms.Normalize(...), ]) ``` or ```python pipeline = [ transforms.Resize(...), transforms.CenterCrop(...), ] if foo: pipeline.append(...) else: pipeline.extend(...) pipeline.extend([ transforms.ToTensor(), transforms.Normalize(...), ]) transform = transforms.Compose(pipeline) ``` In both cases, we don't want to have an extra `Compose` in there, but just replace `ToTensor` with `v2.ToImage(), v2.ToDtype(torch.float32, scale=True)`. I don't have much experience with `libcst`, but from what I did in this PR I think 1. is fairly straightforward, while 2. can become a nightmare.
1 parent 77bdb58 commit 6959ea6

File tree

6 files changed

+349
-265
lines changed

6 files changed

+349
-265
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from torchvision.transforms.v2 import ToTensor
2+
3+
from torchvision.transforms.v2 import ToTensor as ToTensorAlias
4+
5+
from torchvision.transforms.v2 import (
6+
ToImage,
7+
ToTensor,
8+
ToDtype,
9+
)
10+
11+
from torchvision.transforms import v2
12+
13+
v2.ToTensor
14+
15+
from torchvision.transforms import v2 as transforms
16+
17+
transforms.ToTensor
18+
19+
20+
import torchvision.transforms.v2 as transforms
21+
22+
transforms.ToTensor
23+
24+
import torchvision.transforms.v2
25+
26+
torchvision.transforms.v2.ToTensor
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
1:39 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
2+
3:39 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
3+
7:5 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
4+
13:1 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
5+
17:1 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
6+
22:1 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.
7+
26:1 TOR202 The transform `v2.ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.

torchfix/torchfix.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
from .visitors.performance import TorchSynchronizedDataLoaderVisitor
1414
from .visitors.misc import TorchRequireGradVisitor
15-
from .visitors.vision import TorchVisionDeprecatedPretrainedVisitor
15+
from .visitors.vision import (
16+
TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor
17+
)
1618

1719
__version__ = "0.0.3"
1820

@@ -27,6 +29,7 @@ def GET_ALL_VISITORS():
2729
TorchRequireGradVisitor(),
2830
TorchSynchronizedDataLoaderVisitor(),
2931
TorchVisionDeprecatedPretrainedVisitor(),
32+
TorchVisionDeprecatedToTensorVisitor(),
3033
]
3134

3235

torchfix/visitors/vision/__init__.py

Lines changed: 2 additions & 264 deletions
Original file line numberDiff line numberDiff line change
@@ -1,264 +1,2 @@
1-
from typing import Optional
2-
3-
import libcst as cst
4-
import libcst.matchers as m
5-
from libcst.codemod.visitors import ImportItem
6-
7-
from ...common import LintViolation, TorchVisitor
8-
9-
10-
class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
11-
"""
12-
Find and fix deprecated `pretrained` parameters in TorchVision models.
13-
14-
Both `pretrained` and `pretrained_backbone` parameters are supported.
15-
The parameters are updated to the new `weights` and `weights_backbone` parameters
16-
only if the old parameter has explicit literal `True` or `False` value,
17-
otherwise only lint violation is emitted.
18-
"""
19-
20-
ERROR_CODE = "TOR201"
21-
22-
# flake8: noqa: E105
23-
# fmt: off
24-
MODEL_WEIGHTS = {
25-
("mobilenet_v2", "pretrained"): "MobileNet_V2_Weights.IMAGENET1K_V1",
26-
("mobilenet_v3_large", "pretrained"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
27-
("mobilenet_v3_small", "pretrained"): "MobileNet_V3_Small_Weights.IMAGENET1K_V1",
28-
("densenet121", "pretrained"): "DenseNet121_Weights.IMAGENET1K_V1",
29-
("densenet161", "pretrained"): "DenseNet161_Weights.IMAGENET1K_V1",
30-
("densenet169", "pretrained"): "DenseNet169_Weights.IMAGENET1K_V1",
31-
("densenet201", "pretrained"): "DenseNet201_Weights.IMAGENET1K_V1",
32-
("detection.maskrcnn_resnet50_fpn", "pretrained"): "MaskRCNN_ResNet50_FPN_Weights.COCO_V1",
33-
("detection.maskrcnn_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
34-
("detection.maskrcnn_resnet50_fpn_v2", "pretrained"): "MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1",
35-
("detection.maskrcnn_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
36-
("detection.retinanet_resnet50_fpn", "pretrained"): "RetinaNet_ResNet50_FPN_Weights.COCO_V1",
37-
("detection.retinanet_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
38-
("detection.retinanet_resnet50_fpn_v2", "pretrained"): "RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1",
39-
("detection.retinanet_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
40-
("optical_flow.raft_large", "pretrained"): "Raft_Large_Weights.C_T_SKHT_V2",
41-
("optical_flow.raft_small", "pretrained"): "Raft_Small_Weights.C_T_V2",
42-
("alexnet", "pretrained"): "AlexNet_Weights.IMAGENET1K_V1",
43-
("convnext_tiny", "pretrained"): "ConvNeXt_Tiny_Weights.IMAGENET1K_V1",
44-
("convnext_small", "pretrained"): "ConvNeXt_Small_Weights.IMAGENET1K_V1",
45-
("convnext_base", "pretrained"): "ConvNeXt_Base_Weights.IMAGENET1K_V1",
46-
("convnext_large", "pretrained"): "ConvNeXt_Large_Weights.IMAGENET1K_V1",
47-
("inception_v3", "pretrained"): "Inception_V3_Weights.IMAGENET1K_V1",
48-
("maxvit_t", "pretrained"): "MaxVit_T_Weights.IMAGENET1K_V1",
49-
("mnasnet0_5", "pretrained"): "MNASNet0_5_Weights.IMAGENET1K_V1",
50-
("mnasnet0_75", "pretrained"): "MNASNet0_75_Weights.IMAGENET1K_V1",
51-
("mnasnet1_0", "pretrained"): "MNASNet1_0_Weights.IMAGENET1K_V1",
52-
("mnasnet1_3", "pretrained"): "MNASNet1_3_Weights.IMAGENET1K_V1",
53-
("detection.fasterrcnn_resnet50_fpn", "pretrained"): "FasterRCNN_ResNet50_FPN_Weights.COCO_V1",
54-
("detection.fasterrcnn_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
55-
("detection.fasterrcnn_resnet50_fpn_v2", "pretrained"): "FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1",
56-
("detection.fasterrcnn_resnet50_fpn_v2", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
57-
("detection.fasterrcnn_mobilenet_v3_large_320_fpn", "pretrained"): "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1",
58-
("detection.fasterrcnn_mobilenet_v3_large_320_fpn", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
59-
("detection.fasterrcnn_mobilenet_v3_large_fpn", "pretrained"): "FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1",
60-
("detection.fasterrcnn_mobilenet_v3_large_fpn", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
61-
("detection.fcos_resnet50_fpn", "pretrained"): "FCOS_ResNet50_FPN_Weights.COCO_V1",
62-
("detection.fcos_resnet50_fpn", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
63-
("segmentation.lraspp_mobilenet_v3_large", "pretrained"): "LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1",
64-
("segmentation.lraspp_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
65-
("shufflenet_v2_x0_5", "pretrained"): "ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1",
66-
("shufflenet_v2_x1_0", "pretrained"): "ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1",
67-
("shufflenet_v2_x1_5", "pretrained"): "ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1",
68-
("shufflenet_v2_x2_0", "pretrained"): "ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1",
69-
("squeezenet1_0", "pretrained"): "SqueezeNet1_0_Weights.IMAGENET1K_V1",
70-
("squeezenet1_1", "pretrained"): "SqueezeNet1_1_Weights.IMAGENET1K_V1",
71-
("swin_t", "pretrained"): "Swin_T_Weights.IMAGENET1K_V1",
72-
("swin_s", "pretrained"): "Swin_S_Weights.IMAGENET1K_V1",
73-
("swin_b", "pretrained"): "Swin_B_Weights.IMAGENET1K_V1",
74-
("swin_v2_t", "pretrained"): "Swin_V2_T_Weights.IMAGENET1K_V1",
75-
("swin_v2_s", "pretrained"): "Swin_V2_S_Weights.IMAGENET1K_V1",
76-
("swin_v2_b", "pretrained"): "Swin_V2_B_Weights.IMAGENET1K_V1",
77-
("video.s3d", "pretrained"): "S3D_Weights.KINETICS400_V1",
78-
("video.swin3d_t", "pretrained"): "Swin3D_T_Weights.KINETICS400_V1",
79-
("video.swin3d_s", "pretrained"): "Swin3D_S_Weights.KINETICS400_V1",
80-
("video.swin3d_b", "pretrained"): "Swin3D_B_Weights.KINETICS400_V1",
81-
("vit_b_16", "pretrained"): "ViT_B_16_Weights.IMAGENET1K_V1",
82-
("vit_b_32", "pretrained"): "ViT_B_32_Weights.IMAGENET1K_V1",
83-
("vit_l_16", "pretrained"): "ViT_L_16_Weights.IMAGENET1K_V1",
84-
("vit_l_32", "pretrained"): "ViT_L_32_Weights.IMAGENET1K_V1",
85-
("vit_h_14", "pretrained"): "None",
86-
("vgg11", "pretrained"): "VGG11_Weights.IMAGENET1K_V1",
87-
("vgg11_bn", "pretrained"): "VGG11_BN_Weights.IMAGENET1K_V1",
88-
("vgg13", "pretrained"): "VGG13_Weights.IMAGENET1K_V1",
89-
("vgg13_bn", "pretrained"): "VGG13_BN_Weights.IMAGENET1K_V1",
90-
("vgg16", "pretrained"): "VGG16_Weights.IMAGENET1K_V1",
91-
("vgg16_bn", "pretrained"): "VGG16_BN_Weights.IMAGENET1K_V1",
92-
("vgg19", "pretrained"): "VGG19_Weights.IMAGENET1K_V1",
93-
("vgg19_bn", "pretrained"): "VGG19_BN_Weights.IMAGENET1K_V1",
94-
("video.mvit_v1_b", "pretrained"): "MViT_V1_B_Weights.KINETICS400_V1",
95-
("video.mvit_v2_s", "pretrained"): "MViT_V2_S_Weights.KINETICS400_V1",
96-
("video.r3d_18", "pretrained"): "R3D_18_Weights.KINETICS400_V1",
97-
("video.mc3_18", "pretrained"): "MC3_18_Weights.KINETICS400_V1",
98-
("video.r2plus1d_18", "pretrained"): "R2Plus1D_18_Weights.KINETICS400_V1",
99-
("regnet_y_400mf", "pretrained"): "RegNet_Y_400MF_Weights.IMAGENET1K_V1",
100-
("regnet_y_800mf", "pretrained"): "RegNet_Y_800MF_Weights.IMAGENET1K_V1",
101-
("regnet_y_1_6gf", "pretrained"): "RegNet_Y_1_6GF_Weights.IMAGENET1K_V1",
102-
("regnet_y_3_2gf", "pretrained"): "RegNet_Y_3_2GF_Weights.IMAGENET1K_V1",
103-
("regnet_y_8gf", "pretrained"): "RegNet_Y_8GF_Weights.IMAGENET1K_V1",
104-
("regnet_y_16gf", "pretrained"): "RegNet_Y_16GF_Weights.IMAGENET1K_V1",
105-
("regnet_y_32gf", "pretrained"): "RegNet_Y_32GF_Weights.IMAGENET1K_V1",
106-
("regnet_y_128gf", "pretrained"): "None",
107-
("regnet_x_400mf", "pretrained"): "RegNet_X_400MF_Weights.IMAGENET1K_V1",
108-
("regnet_x_800mf", "pretrained"): "RegNet_X_800MF_Weights.IMAGENET1K_V1",
109-
("regnet_x_1_6gf", "pretrained"): "RegNet_X_1_6GF_Weights.IMAGENET1K_V1",
110-
("regnet_x_3_2gf", "pretrained"): "RegNet_X_3_2GF_Weights.IMAGENET1K_V1",
111-
("regnet_x_8gf", "pretrained"): "RegNet_X_8GF_Weights.IMAGENET1K_V1",
112-
("regnet_x_16gf", "pretrained"): "RegNet_X_16GF_Weights.IMAGENET1K_V1",
113-
("regnet_x_32gf", "pretrained"): "RegNet_X_32GF_Weights.IMAGENET1K_V1",
114-
("resnet18", "pretrained"): "ResNet18_Weights.IMAGENET1K_V1",
115-
("resnet34", "pretrained"): "ResNet34_Weights.IMAGENET1K_V1",
116-
("resnet50", "pretrained"): "ResNet50_Weights.IMAGENET1K_V1",
117-
("resnet101", "pretrained"): "ResNet101_Weights.IMAGENET1K_V1",
118-
("resnet152", "pretrained"): "ResNet152_Weights.IMAGENET1K_V1",
119-
("resnext50_32x4d", "pretrained"): "ResNeXt50_32X4D_Weights.IMAGENET1K_V1",
120-
("resnext101_32x8d", "pretrained"): "ResNeXt101_32X8D_Weights.IMAGENET1K_V1",
121-
("resnext101_64x4d", "pretrained"): "ResNeXt101_64X4D_Weights.IMAGENET1K_V1",
122-
("wide_resnet50_2", "pretrained"): "Wide_ResNet50_2_Weights.IMAGENET1K_V1",
123-
("wide_resnet101_2", "pretrained"): "Wide_ResNet101_2_Weights.IMAGENET1K_V1",
124-
("efficientnet_b0", "pretrained"): "EfficientNet_B0_Weights.IMAGENET1K_V1",
125-
("efficientnet_b1", "pretrained"): "EfficientNet_B1_Weights.IMAGENET1K_V1",
126-
("efficientnet_b2", "pretrained"): "EfficientNet_B2_Weights.IMAGENET1K_V1",
127-
("efficientnet_b3", "pretrained"): "EfficientNet_B3_Weights.IMAGENET1K_V1",
128-
("efficientnet_b4", "pretrained"): "EfficientNet_B4_Weights.IMAGENET1K_V1",
129-
("efficientnet_b5", "pretrained"): "EfficientNet_B5_Weights.IMAGENET1K_V1",
130-
("efficientnet_b6", "pretrained"): "EfficientNet_B6_Weights.IMAGENET1K_V1",
131-
("efficientnet_b7", "pretrained"): "EfficientNet_B7_Weights.IMAGENET1K_V1",
132-
("efficientnet_v2_s", "pretrained"): "EfficientNet_V2_S_Weights.IMAGENET1K_V1",
133-
("efficientnet_v2_m", "pretrained"): "EfficientNet_V2_M_Weights.IMAGENET1K_V1",
134-
("efficientnet_v2_l", "pretrained"): "EfficientNet_V2_L_Weights.IMAGENET1K_V1",
135-
("googlenet", "pretrained"): "GoogLeNet_Weights.IMAGENET1K_V1",
136-
("segmentation.deeplabv3_resnet50", "pretrained"): "DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1",
137-
("segmentation.deeplabv3_resnet50", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
138-
("segmentation.deeplabv3_resnet101", "pretrained"): "DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1",
139-
("segmentation.deeplabv3_resnet101", "pretrained_backbone"): "ResNet101_Weights.IMAGENET1K_V1",
140-
("segmentation.deeplabv3_mobilenet_v3_large", "pretrained"): "DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1",
141-
("segmentation.deeplabv3_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
142-
("segmentation.fcn_resnet50", "pretrained"): "FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1",
143-
("segmentation.fcn_resnet50", "pretrained_backbone"): "ResNet50_Weights.IMAGENET1K_V1",
144-
("segmentation.fcn_resnet101", "pretrained"): "FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1",
145-
("segmentation.fcn_resnet101", "pretrained_backbone"): "ResNet101_Weights.IMAGENET1K_V1",
146-
("detection.ssd300_vgg16", "pretrained"): "SSD300_VGG16_Weights.COCO_V1",
147-
("detection.ssd300_vgg16", "pretrained_backbone"): "VGG16_Weights.IMAGENET1K_FEATURES",
148-
("detection.ssdlite320_mobilenet_v3_large", "pretrained"): "SSDLite320_MobileNet_V3_Large_Weights.COCO_V1",
149-
("detection.ssdlite320_mobilenet_v3_large", "pretrained_backbone"): "MobileNet_V3_Large_Weights.IMAGENET1K_V1",
150-
}
151-
# fmt: on
152-
153-
# The same model can be imported from torchvision.models directly,
154-
# or from a submodule like torchvision.models.resnet.
155-
MODEL_SUBMODULES = (
156-
"alexnet",
157-
"convnext",
158-
"densenet",
159-
"efficientnet",
160-
"googlenet",
161-
"inception",
162-
"mnasnet",
163-
"mobilenet",
164-
"regnet",
165-
"resnet",
166-
"shufflenetv2",
167-
"squeezenet",
168-
"vgg",
169-
"vision_transformer",
170-
"swin_transformer",
171-
"maxvit",
172-
)
173-
174-
def visit_Call(self, node):
175-
def _new_arg_and_import(
176-
old_arg: cst.Arg, is_backbone: bool
177-
) -> Optional[cst.Arg]:
178-
old_arg_name = "pretrained_backbone" if is_backbone else "pretrained"
179-
if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS:
180-
return None
181-
new_arg_name = "weights_backbone" if is_backbone else "weights"
182-
weights_arg = None
183-
if cst.ensure_type(old_arg.value, cst.Name).value == "True":
184-
weights_str = self.MODEL_WEIGHTS[(model_name, old_arg_name)]
185-
if is_backbone is False and len(model_name.split(".")) > 1:
186-
# Prepend things like 'detection.' to the weights string
187-
weights_str = model_name.split(".")[0] + "." + weights_str
188-
weights_str = "models." + weights_str
189-
weights_arg = cst.ensure_type(
190-
cst.parse_expression(f"f({new_arg_name}={weights_str})"), cst.Call
191-
).args[0]
192-
self.needed_imports.add(
193-
ImportItem(
194-
module_name="torchvision",
195-
obj_name="models",
196-
)
197-
)
198-
elif cst.ensure_type(old_arg.value, cst.Name).value == "False":
199-
weights_arg = cst.ensure_type(
200-
cst.parse_expression(f"f({new_arg_name}=None)"), cst.Call
201-
).args[0]
202-
return weights_arg
203-
204-
qualified_name = self.get_qualified_name_for_call(node)
205-
if qualified_name is None:
206-
return
207-
if qualified_name.startswith("torchvision.models"):
208-
model_name = qualified_name[len("torchvision.models") + 1 :]
209-
for submodule in self.MODEL_SUBMODULES:
210-
if model_name.startswith(submodule + "."):
211-
model_name = model_name[len(submodule) + 1 :]
212-
213-
if (model_name, "pretrained") not in self.MODEL_WEIGHTS:
214-
return
215-
216-
message = None
217-
pretrained_arg = self.get_specific_arg(node, "pretrained", 0)
218-
if pretrained_arg is not None:
219-
message = "Parameter `pretrained` is deprecated, please use `weights` instead."
220-
221-
pretrained_backbone_arg = self.get_specific_arg(
222-
node, "pretrained_backbone", 1
223-
)
224-
if pretrained_backbone_arg is not None:
225-
message = "Parameter `pretrained_backbone` is deprecated, please use `weights_backbone` instead."
226-
227-
replacement_args = list(node.args)
228-
229-
new_pretrained_arg = _new_arg_and_import(pretrained_arg, is_backbone=False)
230-
has_replacement = False
231-
if new_pretrained_arg is not None:
232-
for pos, arg in enumerate(node.args):
233-
if arg is pretrained_arg:
234-
break
235-
replacement_args[pos] = new_pretrained_arg
236-
has_replacement = True
237-
238-
new_pretrained_backbone_arg = _new_arg_and_import(
239-
pretrained_backbone_arg, is_backbone=True
240-
)
241-
if new_pretrained_backbone_arg is not None:
242-
for pos, arg in enumerate(node.args):
243-
if arg is pretrained_backbone_arg:
244-
break
245-
replacement_args[pos] = new_pretrained_backbone_arg
246-
has_replacement = True
247-
248-
replacement = (
249-
node.with_changes(args=replacement_args) if has_replacement else None
250-
)
251-
if message is not None:
252-
position_metadata = self.get_metadata(
253-
cst.metadata.WhitespaceInclusivePositionProvider, node
254-
)
255-
self.violations.append(
256-
LintViolation(
257-
error_code=self.ERROR_CODE,
258-
message=message,
259-
line=position_metadata.start.line,
260-
column=position_metadata.start.column,
261-
node=node,
262-
replacement=replacement,
263-
)
264-
)
1+
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
2+
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401

0 commit comments

Comments
 (0)