From afbdd24b9707201a24a543cba0479c096e84c1b9 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Thu, 17 Apr 2025 14:27:06 -0700 Subject: [PATCH] Use torch.accelerator API in Fast Neural Style example The `torch.accelerator` API allows user scripts to abstract away some of the accelerator specifics. This commit updates the Fast Neural Style example to use that API. Things to note: * The commit changes the example command line by replacing the accelerator-specific flags (`--cuda`, `--mps`, `--xpu`) with `--accel` * The `torch.accelerator` API appeared in PyTorch 2.6, so the commit requires `torch>=2.6` for this example * The commit introduces a `USE_ACCEL` flag for the test `*.sh` script. At the moment this flag is just an alias for `USE_CUDA`. In the future it should replace `USE_CUDA` once all examples are converted to use the `torch.accelerator` API. Signed-off-by: Dmitry Rogozhkin --- fast_neural_style/README.md | 14 +++--- .../neural_style/neural_style.py | 45 +++++++------------ fast_neural_style/requirements.txt | 2 +- run_python_examples.sh | 15 ++++++- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/fast_neural_style/README.md b/fast_neural_style/README.md index c7fbe80320..9b5834ed7d 100644 --- a/fast_neural_style/README.md +++ b/fast_neural_style/README.md @@ -19,21 +19,19 @@ The program is written in Python, and uses [pytorch](http://pytorch.org/), [scip Stylize image ``` -python neural_style/neural_style.py eval --content-image --model --output-image --cuda 0 +python neural_style/neural_style.py eval --content-image --model --output-image --accel ``` - `--content-image`: path to content image you want to stylize. - `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`) - `--output-image`: path for saving the output image. - `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image) -- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU. 
-- `--mps`: use MPS device backend. -- `--xpu`: use XPU device backend. +- `--accel`: use accelerator Train model ```bash -python neural_style/neural_style.py train --dataset --style-image --save-model-dir --epochs 2 --cuda 1 +python neural_style/neural_style.py train --dataset --style-image --save-model-dir --epochs 2 --accel ``` There are several command line arguments, the important ones are listed below @@ -41,9 +39,9 @@ There are several command line arguments, the important ones are listed below - `--dataset`: path to training dataset, the path should point to a folder containing another folder with all the training images. I used COCO 2014 Training images dataset [80K/13GB] [(download)](https://cocodataset.org/#download). - `--style-image`: path to style-image. - `--save-model-dir`: path to folder where trained model will be saved. -- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU. -- `--mps`: use MPS device backend. -- `--xpu`: use XPU device backend. +- `--accel`: use accelerator. + +If `--accel` argument is given, pytorch will search for available hardware acceleration device and attempt to use it. This example is known to work on CUDA, MPS and XPU devices. Refer to `neural_style/neural_style.py` for other command line arguments. For training new models you might have to tune the values of `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining 3 models were also trained with similar order of weight parameters with slight variation in the `--style-weight` (`5e10` or `1e11`). 
diff --git a/fast_neural_style/neural_style/neural_style.py b/fast_neural_style/neural_style/neural_style.py index e51007c157..fa98525692 100644 --- a/fast_neural_style/neural_style/neural_style.py +++ b/fast_neural_style/neural_style/neural_style.py @@ -29,16 +29,12 @@ def check_paths(args): def train(args): - if args.cuda: - device = torch.device("cuda") - elif args.mps: - device = torch.device("mps") - elif args.xpu: - device = torch.device("xpu") + if args.accel: + device = torch.accelerator.current_accelerator() else: device = torch.device("cpu") - print("Device to use: ", device) + print(f"Using device: {device}") np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -129,10 +125,12 @@ def train(args): def stylize(args): - device = torch.device("cuda" if args.cuda else "cpu") - device = torch.device("xpu" if args.xpu else "cpu") + if args.accel: + device = torch.accelerator.current_accelerator() + else: + device = torch.device("cpu") - print("Device to use: ", device) + print(f"Using device: {device}") content_image = utils.load_image(args.content_image, scale=args.content_scale) content_transform = transforms.Compose([ @@ -212,8 +210,8 @@ def main(): help="size of training images, default is 256 X 256") train_arg_parser.add_argument("--style-size", type=int, default=None, help="size of style-image, default is the original size of style image") - train_arg_parser.add_argument("--cuda", type=int, required=True, - help="set it to 1 for running on GPU, 0 for CPU") + train_arg_parser.add_argument('--accel', action='store_true', + help='use accelerator') train_arg_parser.add_argument("--seed", type=int, default=42, help="random seed for training") train_arg_parser.add_argument("--content-weight", type=float, default=1e5, @@ -226,10 +224,6 @@ def main(): help="number of images after which the training loss is logged, default is 500") train_arg_parser.add_argument("--checkpoint-interval", type=int, default=2000, help="number of batches after which a checkpoint 
of the trained model will be created") - train_arg_parser.add_argument('--mps', action='store_true', - help='enable macOS GPU training') - train_arg_parser.add_argument('--xpu', action='store_true', - help='enable Intel XPU training') eval_arg_parser = subparsers.add_parser("eval", help="parser for evaluation/stylizing arguments") eval_arg_parser.add_argument("--content-image", type=str, required=True, @@ -240,28 +234,21 @@ def main(): help="path for saving the output image") eval_arg_parser.add_argument("--model", type=str, required=True, help="saved model to be used for stylizing the image. If file ends in .pth - PyTorch path is used, if in .onnx - Caffe2 path") - eval_arg_parser.add_argument("--cuda", type=int, default=False, - help="set it to 1 for running on cuda, 0 for CPU") eval_arg_parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") - eval_arg_parser.add_argument('--mps', action='store_true', - help='enable macOS GPU evaluation') - eval_arg_parser.add_argument('--xpu', action='store_true', - help='enable Intel XPU evaluation') - + eval_arg_parser.add_argument('--accel', action='store_true', + help='use accelerator') args = main_arg_parser.parse_args() if args.subcommand is None: print("ERROR: specify either train or eval") sys.exit(1) - if args.cuda and not torch.cuda.is_available(): - print("ERROR: cuda is not available, try running on CPU") + if args.accel and not torch.accelerator.is_available(): + print("ERROR: accelerator is not available, try running on CPU") sys.exit(1) - if not args.mps and torch.backends.mps.is_available(): - print("WARNING: mps is available, run with --mps to enable macOS GPU") - if not args.xpu and torch.xpu.is_available(): - print("WARNING: XPU is available, run with --xpu to enable Intel XPU") + if not args.accel and torch.accelerator.is_available(): + print("WARNING: accelerator is available, run with --accel to enable it") if args.subcommand == "train": check_paths(args) diff --git 
a/fast_neural_style/requirements.txt b/fast_neural_style/requirements.txt index cef06d7884..54d4c008f0 100644 --- a/fast_neural_style/requirements.txt +++ b/fast_neural_style/requirements.txt @@ -1,3 +1,3 @@ numpy -torch +torch>=2.6 torchvision diff --git a/run_python_examples.sh b/run_python_examples.sh index d12fb60021..b15f9397d7 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -13,6 +13,12 @@ # To test examples on CUDA accelerator, run as: # USE_CUDA=True ./run_python_examples.sh # +# To test examples on hardware accelerator (CUDA, MPS, XPU, etc.), run as: +# USE_ACCEL=True ./run_python_examples.sh +# NOTE: USE_ACCEL relies on torch.accelerator API and not all examples are converted +# to use it at the moment. Thus, expect failures using this flag on non-CUDA accelerators +# and consider to run examples one by one. +# # Script requires uv to be installed. When executed, script will install prerequisites from # `requirements.txt` for each example. If ran within activated virtual environment (uv venv, # python -m venv, conda) this might reinstall some of the packages. To change pip installation @@ -27,17 +33,24 @@ BASE_DIR="$(pwd)/$(dirname $0)" source $BASE_DIR/utils.sh +# TODO: Leave only USE_ACCEL and drop USE_CUDA once all examples will be converted +# to torch.accelerator API. For now, just add USE_ACCEL as an alias for USE_CUDA. 
+if [ -n "$USE_ACCEL" ]; then + USE_CUDA=$USE_ACCEL +fi USE_CUDA=${USE_CUDA:-False} case $USE_CUDA in "True") echo "using cuda" CUDA=1 CUDA_FLAG="--cuda" + ACCEL_FLAG="--accel" ;; "False") echo "not using cuda" CUDA=0 CUDA_FLAG="" + ACCEL_FLAG="" ;; "") exit 1; @@ -56,7 +69,7 @@ function fast_neural_style() { test -d "saved_models" || { error "saved models not found"; return; } echo "running fast neural style model" - uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA --mps || error "neural_style.py failed" + uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg $ACCEL_FLAG || error "neural_style.py failed" } function imagenet() {