Skip to content

Commit a0dc5b3

Browse files
authored
Update deprecated TorchVision pretrained API (#2573)
1 parent fb43d92 commit a0dc5b3

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

prototype_source/numeric_suite_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch
2929
import torch.nn as nn
3030
import torchvision
31-
from torchvision import datasets
31+
from torchvision import models, datasets
3232
import torchvision.transforms as transforms
3333
import os
3434
import torch.quantization
@@ -43,7 +43,7 @@
4343
# Then we load the pretrained float ResNet18 model, and quantize it into qmodel. We cannot compare two arbitrary models, only a float model and the quantized model derived from it can be compared.
4444

4545

46-
float_model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
46+
float_model = torchvision.models.quantization.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1, quantize=False)
4747
float_model.to('cpu')
4848
float_model.eval()
4949
float_model.fuse_model()
@@ -199,7 +199,7 @@ def forward(self, x):
199199
#
200200
# Notice before each call of ``compare_model_outputs()`` and ``compare_model_stub()`` we need to have clean float and quantized model. This is because ``compare_model_outputs()`` and ``compare_model_stub()`` modify float and quantized model inplace, and it will cause unexpected results if call one right after another.
201201

202-
float_model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
202+
float_model = torchvision.models.quantization.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1, quantize=False)
203203
float_model.to('cpu')
204204
float_model.eval()
205205
float_model.fuse_model()

recipes_source/recipes/Captum_Recipe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242
#
4343

4444
import torchvision
45-
from torchvision import transforms
45+
from torchvision import models, transforms
4646
from PIL import Image
4747
import requests
4848
from io import BytesIO
4949

50-
model = torchvision.models.resnet18(pretrained=True).eval()
50+
model = torchvision.models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()
5151

5252
response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")
5353
img = Image.open(BytesIO(response.content))

0 commit comments

Comments
 (0)