Skip to content

Commit 77bdb58

Browse files
authored
[TorchFix] Make pretrained rule work when a model imported from a submodule (#4613)
Update the rule for `pretrained` torchvision models parameter to work when a model imported from a submodule, like `models.resnet.resnet101` instead of `models.resnet101` (which is the same model).
1 parent 76f4ebb commit 77bdb58

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

tests/fixtures/vision/checker/pretrained.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@
1717
from torchvision.models import ResNet50_Weights
1818
torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
1919
torchvision.models.resnet50(weights=None)
20+
21+
# Make sure no false positives on non-model functions
22+
from torchvision.models.vgg import make_layers, cfgs
23+
make_layers(cfgs['D'], False)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Check that the codemod works when a model
2+
# imported from a submodule
3+
# i.e. from torchvision.models.resnet instead of
4+
# directly from torchvision.models.
5+
6+
from torchvision import models
7+
backbone = models.resnet.resnet101(pretrained=True, replace_stride_with_dilation=[False, True, True])
8+
9+
from torchvision.models import resnet
10+
backbone = resnet.resnet101(pretrained=True, replace_stride_with_dilation=[False, True, True])
11+
12+
from torchvision.models.resnet import resnet152
13+
resnet152(pretrained=True)
14+
resnet152(pretrained=False)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Check that the codemod works when a model
2+
# imported from a submodule
3+
# i.e. from torchvision.models.resnet instead of
4+
# directly from torchvision.models.
5+
6+
from torchvision import models
7+
backbone = models.resnet.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1, replace_stride_with_dilation=[False, True, True])
8+
9+
from torchvision.models import resnet
10+
backbone = resnet.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1, replace_stride_with_dilation=[False, True, True])
11+
12+
from torchvision.models.resnet import resnet152
13+
resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
14+
resnet152(weights=None)

torchfix/visitors/vision/__init__.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
150150
}
151151
# fmt: on
152152

153+
# The same model can be imported from torchvision.models directly,
154+
# or from a submodule like torchvision.models.resnet.
155+
MODEL_SUBMODULES = (
156+
"alexnet",
157+
"convnext",
158+
"densenet",
159+
"efficientnet",
160+
"googlenet",
161+
"inception",
162+
"mnasnet",
163+
"mobilenet",
164+
"regnet",
165+
"resnet",
166+
"shufflenetv2",
167+
"squeezenet",
168+
"vgg",
169+
"vision_transformer",
170+
"swin_transformer",
171+
"maxvit",
172+
)
173+
153174
def visit_Call(self, node):
154175
def _new_arg_and_import(
155176
old_arg: cst.Arg, is_backbone: bool
@@ -185,8 +206,14 @@ def _new_arg_and_import(
185206
return
186207
if qualified_name.startswith("torchvision.models"):
187208
model_name = qualified_name[len("torchvision.models") + 1 :]
188-
message = None
209+
for submodule in self.MODEL_SUBMODULES:
210+
if model_name.startswith(submodule + "."):
211+
model_name = model_name[len(submodule) + 1 :]
189212

213+
if (model_name, "pretrained") not in self.MODEL_WEIGHTS:
214+
return
215+
216+
message = None
190217
pretrained_arg = self.get_specific_arg(node, "pretrained", 0)
191218
if pretrained_arg is not None:
192219
message = "Parameter `pretrained` is deprecated, please use `weights` instead."

0 commit comments

Comments
 (0)