diff --git a/gat/README.md b/gat/README.md
index 7bb71bc17b..c15076a597 100644
--- a/gat/README.md
+++ b/gat/README.md
@@ -87,8 +87,7 @@ options:
   --concat-heads        wether to concatinate attention heads, or average over them (default: False)
   --val-every VAL_EVERY
                         epochs to wait for print training and validation evaluation (default: 20)
-  --no-cuda             disables CUDA training
-  --no-mps              disables macOS GPU training
+  --no-accel            disables accelerator
   --dry-run             quickly check a single pass
   --seed S              random seed (default: 13)
 ```
diff --git a/gat/main.py b/gat/main.py
index 9c143af8ec..948833f3ae 100644
--- a/gat/main.py
+++ b/gat/main.py
@@ -303,29 +303,25 @@ def test(model, criterion, input, target, mask):
                         help='dimension of the hidden representation (default: 64)')
     parser.add_argument('--num-heads', type=int, default=8,
                         help='number of the attention heads (default: 4)')
-    parser.add_argument('--concat-heads', action='store_true', default=False,
+    parser.add_argument('--concat-heads', action='store_true',
                         help='wether to concatinate attention heads, or average over them (default: False)')
     parser.add_argument('--val-every', type=int, default=20,
                         help='epochs to wait for print training and validation evaluation (default: 20)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
+    parser.add_argument('--no-accel', action='store_true',
                         help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
+    parser.add_argument('--dry-run', action='store_true',
                         help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=13, metavar='S',
                         help='random seed (default: 13)')
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
-    use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
 
     # Set the device to run on
-    if use_cuda:
-        device = torch.device('cuda')
-    elif use_mps:
-        device = torch.device('mps')
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
     else:
         device = torch.device('cpu')
     print(f'Using {device} device')
diff --git a/gat/requirements.txt b/gat/requirements.txt
index a9d45a75ec..a4cf014769 100644
--- a/gat/requirements.txt
+++ b/gat/requirements.txt
@@ -1,3 +1,3 @@
 torch
 requests
-numpy<2
+numpy
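
For reference, the accelerator-or-CPU selection this patch switches to can be exercised on its own. The sketch below is not part of the patch: it assumes PyTorch 2.6 or newer (where the `torch.accelerator` API is available) and uses a hypothetical `pick_device` helper name that does not appear in `gat/main.py`.

```python
# Hypothetical helper (not part of the patch) illustrating the same
# device-selection logic that gat/main.py now performs inline.
# Assumes PyTorch 2.6+, where torch.accelerator is available.
import torch

def pick_device(no_accel: bool = False) -> torch.device:
    # torch.accelerator.is_available() is True when any supported
    # accelerator backend (CUDA, MPS, XPU, ...) can be used.
    if not no_accel and torch.accelerator.is_available():
        # Returns the detected accelerator as a torch.device.
        return torch.accelerator.current_accelerator()
    # Fall back to CPU when no accelerator is present or the user opts out.
    return torch.device('cpu')

print(f'Using {pick_device()} device')
```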