Skip to content

Commit f677078

Browse files
committed
Use torch.acceleratort API in VAE example
1 parent 6967ff5 commit f677078

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

vae/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ The main.py script accepts the following arguments:
1414
optional arguments:
1515
--batch-size input batch size for training (default: 128)
1616
--epochs number of epochs to train (default: 10)
17-
--no-cuda enables CUDA training
18-
--mps enables GPU on macOS
17+
--accel use accelerator
1918
--seed random seed (default: 1)
2019
--log-interval how many batches to wait before logging training status
2120
```

vae/main.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,31 @@
1313
help='input batch size for training (default: 128)')
1414
parser.add_argument('--epochs', type=int, default=10, metavar='N',
1515
help='number of epochs to train (default: 10)')
16-
parser.add_argument('--no-cuda', action='store_true', default=False,
17-
help='disables CUDA training')
18-
parser.add_argument('--no-mps', action='store_true', default=False,
19-
help='disables macOS GPU training')
16+
parser.add_argument('--accel', action='store_true',
17+
help='use accelerator')
2018
parser.add_argument('--seed', type=int, default=1, metavar='S',
2119
help='random seed (default: 1)')
2220
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
2321
help='how many batches to wait before logging training status')
2422
args = parser.parse_args()
25-
args.cuda = not args.no_cuda and torch.cuda.is_available()
26-
use_mps = not args.no_mps and torch.backends.mps.is_available()
23+
2724

2825
torch.manual_seed(args.seed)
2926

30-
if args.cuda:
31-
device = torch.device("cuda")
32-
elif use_mps:
33-
device = torch.device("mps")
27+
if args.accel and not torch.accelerator.is_available():
28+
print("ERROR: accelerator is not available, try running on CPU")
29+
sys.exit(1)
30+
if not args.accel and torch.accelerator.is_available():
31+
print("WARNING: accelerator is available, run with --accel to enable it")
32+
33+
if args.accel:
34+
device = torch.accelerator.current_accelerator()
3435
else:
3536
device = torch.device("cpu")
3637

37-
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
38+
print(f"Using device: {device}")
39+
40+
kwargs = {'num_workers': 1, 'pin_memory': True} if device=="cuda" else {}
3841
train_loader = torch.utils.data.DataLoader(
3942
datasets.MNIST('../data', train=True, download=True,
4043
transform=transforms.ToTensor()),

vae/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
torch
2-
torchvision==0.20.0
2+
torchvision
33
tqdm
44
six

0 commit comments

Comments
 (0)