Skip to content

Commit fc85412

Browse files
committed
Use torch.accelerator API in Fast Neural Style example
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 8393ceb commit fc85412

File tree

3 files changed

+23
-38
lines changed

3 files changed

+23
-38
lines changed

fast_neural_style/README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,27 @@ The program is written in Python, and uses [pytorch](http://pytorch.org/), [scip
1919
Stylize image
2020

2121
```
22-
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --cuda 0
22+
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --accel
2323
```
2424

2525
- `--content-image`: path to content image you want to stylize.
2626
- `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`)
2727
- `--output-image`: path for saving the output image.
2828
- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
29-
- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU.
30-
- `--mps`: use MPS device backend.
31-
- `--xpu`: use XPU device backend.
29+
- `--accel`: use accelerator
3230

3331
Train model
3432

3533
```bash
36-
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --cuda 1
34+
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --accel
3735
```
3836

3937
There are several command line arguments, the important ones are listed below
4038

4139
- `--dataset`: path to training dataset, the path should point to a folder containing another folder with all the training images. I used COCO 2014 Training images dataset [80K/13GB] [(download)](https://cocodataset.org/#download).
4240
- `--style-image`: path to style-image.
4341
- `--save-model-dir`: path to folder where trained model will be saved.
44-
- `--cuda 0|1`: set it to 1 for running on GPU, 0 for CPU.
45-
- `--mps`: use MPS device backend.
46-
- `--xpu`: use XPU device backend.
42+
- `--accel`: use accelerator.
4743

4844
Refer to `neural_style/neural_style.py` for other command line arguments. For training new models you might have to tune the values of `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining 3 models were also trained with similar order of weight parameters with slight variation in the `--style-weight` (`5e10` or `1e11`).
4945

fast_neural_style/neural_style/neural_style.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,12 @@ def check_paths(args):
2929

3030

3131
def train(args):
32-
if args.cuda:
33-
device = torch.device("cuda")
34-
elif args.mps:
35-
device = torch.device("mps")
36-
elif args.xpu:
37-
device = torch.device("xpu")
32+
if args.accel:
33+
device = torch.accelerator.current_accelerator()
3834
else:
3935
device = torch.device("cpu")
4036

41-
print("Device to use: ", device)
37+
print(f"Using device: {device}")
4238

4339
np.random.seed(args.seed)
4440
torch.manual_seed(args.seed)
@@ -129,10 +125,12 @@ def train(args):
129125

130126

131127
def stylize(args):
132-
device = torch.device("cuda" if args.cuda else "cpu")
133-
device = torch.device("xpu" if args.xpu else "cpu")
128+
if args.accel:
129+
device = torch.accelerator.current_accelerator()
130+
else:
131+
device = torch.device("cpu")
134132

135-
print("Device to use: ", device)
133+
print(f"Using device: {device}")
136134

137135
content_image = utils.load_image(args.content_image, scale=args.content_scale)
138136
content_transform = transforms.Compose([
@@ -212,8 +210,8 @@ def main():
212210
help="size of training images, default is 256 X 256")
213211
train_arg_parser.add_argument("--style-size", type=int, default=None,
214212
help="size of style-image, default is the original size of style image")
215-
train_arg_parser.add_argument("--cuda", type=int, required=True,
216-
help="set it to 1 for running on GPU, 0 for CPU")
213+
train_arg_parser.add_argument('--accel', action='store_true',
214+
help='use accelerator')
217215
train_arg_parser.add_argument("--seed", type=int, default=42,
218216
help="random seed for training")
219217
train_arg_parser.add_argument("--content-weight", type=float, default=1e5,
@@ -226,10 +224,6 @@ def main():
226224
help="number of images after which the training loss is logged, default is 500")
227225
train_arg_parser.add_argument("--checkpoint-interval", type=int, default=2000,
228226
help="number of batches after which a checkpoint of the trained model will be created")
229-
train_arg_parser.add_argument('--mps', action='store_true',
230-
help='enable macOS GPU training')
231-
train_arg_parser.add_argument('--xpu', action='store_true',
232-
help='enable Intel XPU training')
233227

234228
eval_arg_parser = subparsers.add_parser("eval", help="parser for evaluation/stylizing arguments")
235229
eval_arg_parser.add_argument("--content-image", type=str, required=True,
@@ -240,28 +234,21 @@ def main():
240234
help="path for saving the output image")
241235
eval_arg_parser.add_argument("--model", type=str, required=True,
242236
help="saved model to be used for stylizing the image. If file ends in .pth - PyTorch path is used, if in .onnx - Caffe2 path")
243-
eval_arg_parser.add_argument("--cuda", type=int, default=False,
244-
help="set it to 1 for running on cuda, 0 for CPU")
245237
eval_arg_parser.add_argument("--export_onnx", type=str,
246238
help="export ONNX model to a given file")
247-
eval_arg_parser.add_argument('--mps', action='store_true',
248-
help='enable macOS GPU evaluation')
249-
eval_arg_parser.add_argument('--xpu', action='store_true',
250-
help='enable Intel XPU evaluation')
251-
239+
eval_arg_parser.add_argument('--accel', action='store_true',
240+
help='use accelerator')
252241

253242
args = main_arg_parser.parse_args()
254243

255244
if args.subcommand is None:
256245
print("ERROR: specify either train or eval")
257246
sys.exit(1)
258-
if args.cuda and not torch.cuda.is_available():
259-
print("ERROR: cuda is not available, try running on CPU")
247+
if args.accel and not torch.accelerator.is_available():
248+
print("ERROR: accelerator is not available, try running on CPU")
260249
sys.exit(1)
261-
if not args.mps and torch.backends.mps.is_available():
262-
print("WARNING: mps is available, run with --mps to enable macOS GPU")
263-
if not args.xpu and torch.xpu.is_available():
264-
print("WARNING: XPU is available, run with --xpu to enable Intel XPU")
250+
if not args.accel and torch.accelerator.is_available():
251+
print("WARNING: accelerator is available, run with --accel to enable it")
265252

266253
if args.subcommand == "train":
267254
check_paths(args)

run_python_examples.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ case $USE_CUDA in
1919
echo "using cuda"
2020
CUDA=1
2121
CUDA_FLAG="--cuda"
22+
ACCEL_FLAG="--accel"
2223
;;
2324
"False")
2425
echo "not using cuda"
2526
CUDA=0
2627
CUDA_FLAG=""
28+
ACCEL_FLAG=""
2729
;;
2830
"")
2931
exit 1;
@@ -44,7 +46,7 @@ function fast_neural_style() {
4446
test -d "saved_models" || { error "saved models not found"; return; }
4547

4648
echo "running fast neural style model"
47-
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA --mps || error "neural_style.py failed"
49+
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg $ACCEL_FLAG || error "neural_style.py failed"
4850
}
4951

5052
function imagenet() {

0 commit comments

Comments
 (0)