Skip to content

Commit 536afb6

Browse files
committed
Fix super_resolution example for torch>=2.6
1 parent ac7e960 commit 536afb6

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

run_python_examples.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ function fx() {
137137

138138
function super_resolution() {
139139
uv run main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 --mps || error "super resolution failed"
140+
uv run super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_1.pth --output_filename out.png || error "super resolution upscaling failed"
140141
}
141142

142143
function time_sequence_prediction() {

super_resolution/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@ optional arguments:
2222
--seed random seed to use. Default=123
2323
```
2424

25-
This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename model*epoch*<epoch_number>.pth
25+
This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth`.
2626

2727
## Example Usage:
2828

2929
### Train
3030

31-
`python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001`
31+
```bash
32+
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001
33+
```
3234

3335
### Super Resolve
3436

35-
`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png`
37+
```bash
38+
python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_30.pth --output_filename out.png
39+
```

super_resolution/super_resolve.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from PIL import Image
55
from torchvision.transforms import ToTensor
6+
from model import Net
67

78
import numpy as np
89

@@ -18,7 +19,16 @@
1819
img = Image.open(opt.input_image).convert('YCbCr')
1920
y, cb, cr = img.split()
2021

21-
model = torch.load(opt.model)
22+
with open(opt.model, 'rb') as f:
23+
safe_globals = [
24+
Net,
25+
torch.nn.modules.activation.ReLU,
26+
torch.nn.modules.conv.Conv2d,
27+
torch.nn.modules.pixelshuffle.PixelShuffle,
28+
]
29+
with torch.serialization.safe_globals(safe_globals):
30+
model = torch.load(f)
31+
2232
img_to_tensor = ToTensor()
2333
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])
2434

0 commit comments

Comments
 (0)