diff --git a/run_python_examples.sh b/run_python_examples.sh index 2d769c0ae1..c017bc78a8 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -137,6 +137,7 @@ function fx() { function super_resolution() { uv run main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 --mps || error "super resolution failed" + uv run super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_1.pth --output_filename out.png || error "super resolution upscaling failed" } function time_sequence_prediction() { diff --git a/super_resolution/README.md b/super_resolution/README.md index 6b5fe831d9..b21a8c4af4 100644 --- a/super_resolution/README.md +++ b/super_resolution/README.md @@ -22,14 +22,18 @@ optional arguments: --seed random seed to use. Default=123 ``` -This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename model*epoch*.pth +This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename `model_epoch_.pth`. ## Example Usage: ### Train -`python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001` +```bash +python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001 +``` ### Super Resolve -`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png` +```bash +python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_30.pth --output_filename out.png +``` diff --git a/super_resolution/super_resolve.py b/super_resolution/super_resolve.py index 750d635312..72ca24f022 100644 --- a/super_resolution/super_resolve.py +++ b/super_resolution/super_resolve.py @@ -3,6 +3,7 @@ import torch from PIL import Image from torchvision.transforms import ToTensor +from model import Net import numpy as np @@ -18,7 +19,16 @@ img = Image.open(opt.input_image).convert('YCbCr') y, cb, cr = img.split() -model = torch.load(opt.model) +with open(opt.model, 'rb') as f: + safe_globals = [ + Net, + torch.nn.modules.activation.ReLU, + torch.nn.modules.conv.Conv2d, + torch.nn.modules.pixelshuffle.PixelShuffle, + ] + with torch.serialization.safe_globals(safe_globals): + model = torch.load(f) + img_to_tensor = ToTensor() input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])