Skip to content

Commit 65afde6

Browse files
authored
Use torch.accelerator API in Fast Neural Style example (#1327)
1 parent 00ef8a7 commit 65afde6

File tree

4 files changed

+37
-39
lines changed

4 files changed

+37
-39
lines changed

fast_neural_style/README.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,29 @@ The program is written in Python, and uses [pytorch](http://pytorch.org/), [scip
1919
Stylize image
2020

2121
```
22-
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --cuda 0
22+
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --accel
2323
```
2424

2525
- `--content-image`: path to content image you want to stylize.
2626
- `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`)
2727
- `--output-image`: path for saving the output image.
2828
- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
29-
- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU.
30-
- `--mps`: use MPS device backend.
31-
- `--xpu`: use XPU device backend.
29+
- `--accel`: use accelerator
3230

3331
Train model
3432

3533
```bash
36-
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --cuda 1
34+
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --accel
3735
```
3836

3937
There are several command line arguments, the important ones are listed below
4038

4139
- `--dataset`: path to training dataset, the path should point to a folder containing another folder with all the training images. I used COCO 2014 Training images dataset [80K/13GB] [(download)](https://cocodataset.org/#download).
4240
- `--style-image`: path to style-image.
4341
- `--save-model-dir`: path to folder where trained model will be saved.
44-
- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU.
45-
- `--mps`: use MPS device backend.
46-
- `--xpu`: use XPU device backend.
42+
- `--accel`: use accelerator.
43+
44+
If `--accel` argument is given, pytorch will search for available hardware acceleration device and attempt to use it. This example is known to work on CUDA, MPS and XPU devices.
4745

4846
Refer to `neural_style/neural_style.py` for other command line arguments. For training new models you might have to tune the values of `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining 3 models were also trained with similar order of weight parameters with slight variation in the `--style-weight` (`5e10` or `1e11`).
4947

fast_neural_style/neural_style/neural_style.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,12 @@ def check_paths(args):
2929

3030

3131
def train(args):
32-
if args.cuda:
33-
device = torch.device("cuda")
34-
elif args.mps:
35-
device = torch.device("mps")
36-
elif args.xpu:
37-
device = torch.device("xpu")
32+
if args.accel:
33+
device = torch.accelerator.current_accelerator()
3834
else:
3935
device = torch.device("cpu")
4036

41-
print("Device to use: ", device)
37+
print(f"Using device: {device}")
4238

4339
np.random.seed(args.seed)
4440
torch.manual_seed(args.seed)
@@ -129,10 +125,12 @@ def train(args):
129125

130126

131127
def stylize(args):
132-
device = torch.device("cuda" if args.cuda else "cpu")
133-
device = torch.device("xpu" if args.xpu else "cpu")
128+
if args.accel:
129+
device = torch.accelerator.current_accelerator()
130+
else:
131+
device = torch.device("cpu")
134132

135-
print("Device to use: ", device)
133+
print(f"Using device: {device}")
136134

137135
content_image = utils.load_image(args.content_image, scale=args.content_scale)
138136
content_transform = transforms.Compose([
@@ -212,8 +210,8 @@ def main():
212210
help="size of training images, default is 256 X 256")
213211
train_arg_parser.add_argument("--style-size", type=int, default=None,
214212
help="size of style-image, default is the original size of style image")
215-
train_arg_parser.add_argument("--cuda", type=int, required=True,
216-
help="set it to 1 for running on GPU, 0 for CPU")
213+
train_arg_parser.add_argument('--accel', action='store_true',
214+
help='use accelerator')
217215
train_arg_parser.add_argument("--seed", type=int, default=42,
218216
help="random seed for training")
219217
train_arg_parser.add_argument("--content-weight", type=float, default=1e5,
@@ -226,10 +224,6 @@ def main():
226224
help="number of images after which the training loss is logged, default is 500")
227225
train_arg_parser.add_argument("--checkpoint-interval", type=int, default=2000,
228226
help="number of batches after which a checkpoint of the trained model will be created")
229-
train_arg_parser.add_argument('--mps', action='store_true',
230-
help='enable macOS GPU training')
231-
train_arg_parser.add_argument('--xpu', action='store_true',
232-
help='enable Intel XPU training')
233227

234228
eval_arg_parser = subparsers.add_parser("eval", help="parser for evaluation/stylizing arguments")
235229
eval_arg_parser.add_argument("--content-image", type=str, required=True,
@@ -240,28 +234,21 @@ def main():
240234
help="path for saving the output image")
241235
eval_arg_parser.add_argument("--model", type=str, required=True,
242236
help="saved model to be used for stylizing the image. If file ends in .pth - PyTorch path is used, if in .onnx - Caffe2 path")
243-
eval_arg_parser.add_argument("--cuda", type=int, default=False,
244-
help="set it to 1 for running on cuda, 0 for CPU")
245237
eval_arg_parser.add_argument("--export_onnx", type=str,
246238
help="export ONNX model to a given file")
247-
eval_arg_parser.add_argument('--mps', action='store_true',
248-
help='enable macOS GPU evaluation')
249-
eval_arg_parser.add_argument('--xpu', action='store_true',
250-
help='enable Intel XPU evaluation')
251-
239+
eval_arg_parser.add_argument('--accel', action='store_true',
240+
help='use accelerator')
252241

253242
args = main_arg_parser.parse_args()
254243

255244
if args.subcommand is None:
256245
print("ERROR: specify either train or eval")
257246
sys.exit(1)
258-
if args.cuda and not torch.cuda.is_available():
259-
print("ERROR: cuda is not available, try running on CPU")
247+
if args.accel and not torch.accelerator.is_available():
248+
print("ERROR: accelerator is not available, try running on CPU")
260249
sys.exit(1)
261-
if not args.mps and torch.backends.mps.is_available():
262-
print("WARNING: mps is available, run with --mps to enable macOS GPU")
263-
if not args.xpu and torch.xpu.is_available():
264-
print("WARNING: XPU is available, run with --xpu to enable Intel XPU")
250+
if not args.accel and torch.accelerator.is_available():
251+
print("WARNING: accelerator is available, run with --accel to enable it")
265252

266253
if args.subcommand == "train":
267254
check_paths(args)

fast_neural_style/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
numpy
2-
torch
2+
torch>=2.6
33
torchvision

run_python_examples.sh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
# To test examples on CUDA accelerator, run as:
1414
# USE_CUDA=True ./run_python_examples.sh
1515
#
16+
# To test examples on hardware accelerator (CUDA, MPS, XPU, etc.), run as:
17+
# USE_ACCEL=True ./run_python_examples.sh
18+
# NOTE: USE_ACCEL relies on torch.accelerator API and not all examples are converted
19+
# to use it at the moment. Thus, expect failures using this flag on non-CUDA accelerators
20+
# and consider to run examples one by one.
21+
#
1622
# Script requires uv to be installed. When executed, script will install prerequisites from
1723
# `requirements.txt` for each example. If ran within activated virtual environment (uv venv,
1824
# python -m venv, conda) this might reinstall some of the packages. To change pip installation
@@ -27,17 +33,24 @@
2733
BASE_DIR="$(pwd)/$(dirname $0)"
2834
source $BASE_DIR/utils.sh
2935

36+
# TODO: Leave only USE_ACCEL and drop USE_CUDA once all examples will be converted
37+
# to torch.accelerator API. For now, just add USE_ACCEL as an alias for USE_CUDA.
38+
if [ -n "$USE_ACCEL" ]; then
39+
USE_CUDA=$USE_ACCEL
40+
fi
3041
USE_CUDA=${USE_CUDA:-False}
3142
case $USE_CUDA in
3243
"True")
3344
echo "using cuda"
3445
CUDA=1
3546
CUDA_FLAG="--cuda"
47+
ACCEL_FLAG="--accel"
3648
;;
3749
"False")
3850
echo "not using cuda"
3951
CUDA=0
4052
CUDA_FLAG=""
53+
ACCEL_FLAG=""
4154
;;
4255
"")
4356
exit 1;
@@ -56,7 +69,7 @@ function fast_neural_style() {
5669
test -d "saved_models" || { error "saved models not found"; return; }
5770

5871
echo "running fast neural style model"
59-
uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA --mps || error "neural_style.py failed"
72+
uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg $ACCEL_FLAG || error "neural_style.py failed"
6073
}
6174

6275
function imagenet() {

0 commit comments

Comments
 (0)