|
13 | 13 | help='input batch size for training (default: 128)')
|
14 | 14 | parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
15 | 15 | help='number of epochs to train (default: 10)')
|
16 |
| -parser.add_argument('--no-cuda', action='store_true', default=False, |
17 |
| - help='disables CUDA training') |
18 |
| -parser.add_argument('--no-mps', action='store_true', default=False, |
19 |
| - help='disables macOS GPU training') |
| 16 | +parser.add_argument('--accel', action='store_true', |
| 17 | + help='use accelerator') |
20 | 18 | parser.add_argument('--seed', type=int, default=1, metavar='S',
|
21 | 19 | help='random seed (default: 1)')
|
22 | 20 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
23 | 21 | help='how many batches to wait before logging training status')
|
24 | 22 | args = parser.parse_args()
|
25 |
| -args.cuda = not args.no_cuda and torch.cuda.is_available() |
26 |
| -use_mps = not args.no_mps and torch.backends.mps.is_available() |
| 23 | + |
27 | 24 |
|
28 | 25 | torch.manual_seed(args.seed)
|
29 | 26 |
|
30 |
| -if args.cuda: |
31 |
| - device = torch.device("cuda") |
32 |
| -elif use_mps: |
33 |
| - device = torch.device("mps") |
| 27 | +if args.accel and not torch.accelerator.is_available(): |
| 28 | + print("ERROR: accelerator is not available, try running on CPU") |
| 29 | + sys.exit(1) |
| 30 | +if not args.accel and torch.accelerator.is_available(): |
| 31 | + print("WARNING: accelerator is available, run with --accel to enable it") |
| 32 | + |
| 33 | +if args.accel: |
| 34 | + device = torch.accelerator.current_accelerator() |
34 | 35 | else:
|
35 | 36 | device = torch.device("cpu")
|
36 | 37 |
|
37 |
| -kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} |
| 38 | +print(f"Using device: {device}") |
| 39 | + |
| 40 | +kwargs = {'num_workers': 1, 'pin_memory': True} if device=="cuda" else {} |
38 | 41 | train_loader = torch.utils.data.DataLoader(
|
39 | 42 | datasets.MNIST('../data', train=True, download=True,
|
40 | 43 | transform=transforms.ToTensor()),
|
|
0 commit comments