
Use torch.accelerator API in VAE example #1338


Merged: 2 commits, May 2, 2025
14 changes: 6 additions & 8 deletions — vae/README.md

````diff
@@ -8,14 +8,12 @@ pip install -r requirements.txt
 python main.py
 ```
 
-The main.py script accepts the following arguments:
+The main.py script accepts the following optional arguments:
 
 ```bash
-optional arguments:
-  --batch-size      input batch size for training (default: 128)
-  --epochs          number of epochs to train (default: 10)
-  --no-cuda         enables CUDA training
-  --mps             enables GPU on macOS
-  --seed            random seed (default: 1)
-  --log-interval    how many batches to wait before logging training status
+  --batch-size      input batch size for training (default: 128)
+  --epochs          number of epochs to train (default: 10)
+  --accel           use accelerator
+  --seed            random seed (default: 1)
+  --log-interval    how many batches to wait before logging training status
 ```
````
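After this change, accelerated training is opt-in rather than the default. For illustration, the two modes would be invoked as:

```sh
# CPU-only run (the new default)
python main.py --epochs 10

# Opt in to whatever accelerator torch detects (CUDA, MPS, XPU, ...)
python main.py --accel --epochs 10
```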
25 changes: 14 additions & 11 deletions — vae/main.py

```diff
@@ -13,28 +13,31 @@
                     help='input batch size for training (default: 128)')
 parser.add_argument('--epochs', type=int, default=10, metavar='N',
                     help='number of epochs to train (default: 10)')
-parser.add_argument('--no-cuda', action='store_true', default=False,
-                    help='disables CUDA training')
-parser.add_argument('--no-mps', action='store_true', default=False,
-                    help='disables macOS GPU training')
+parser.add_argument('--accel', action='store_true',
+                    help='use accelerator')
 parser.add_argument('--seed', type=int, default=1, metavar='S',
                     help='random seed (default: 1)')
 parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
 args = parser.parse_args()
-args.cuda = not args.no_cuda and torch.cuda.is_available()
-use_mps = not args.no_mps and torch.backends.mps.is_available()
 
 torch.manual_seed(args.seed)
 
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+if args.accel and not torch.accelerator.is_available():
+    print("ERROR: accelerator is not available, try running on CPU")
+    sys.exit(1)
+if not args.accel and torch.accelerator.is_available():
+    print("WARNING: accelerator is available, run with --accel to enable it")
+
+if args.accel:
+    device = torch.accelerator.current_accelerator()
 else:
     device = torch.device("cpu")
 
-kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+print(f"Using device: {device}")
+
+kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == "cuda" else {}
 train_loader = torch.utils.data.DataLoader(
     datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.ToTensor()),
```

Contributor (on the `--accel` argument):

> @eromomon: I did not see this PR before it was merged, but you have changed the default in this example. Previously, acceleration was always used (if available) and there was a flag to disable it; now the default is to not use acceleration and there is a flag to enable it. I personally like the new behavior better, but we need to change the CI script to reflect that, and that was not done:
>
> ```
> uv run main.py --epochs 1 || error "vae failed"
> ```
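As the diff shows, `--accel` is a plain `store_true` flag, so acceleration is now off unless explicitly requested, which is the default flip the reviewer points out. A minimal standalone sketch of that opt-in behavior (no torch required):

```python
import argparse

# Same flag shape as in main.py: store_true defaults to False.
parser = argparse.ArgumentParser(description='VAE flag sketch')
parser.add_argument('--accel', action='store_true', help='use accelerator')

defaults = parser.parse_args([])          # no CLI arguments given
opted_in = parser.parse_args(['--accel'])

print(defaults.accel)   # -> False: CPU unless the user opts in
print(opted_in.accel)   # -> True
```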
2 changes: 1 addition & 1 deletion — vae/requirements.txt

```diff
@@ -1,4 +1,4 @@
 torch
-torchvision==0.20.0
+torchvision
 tqdm
 six
```
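For reference, the device-selection logic this PR puts in main.py can be condensed into a standalone helper. This is a sketch, not the PR's code: the helper name `select_device` is hypothetical, and the `hasattr` guard is only there so the snippet also runs on PyTorch builds that predate the `torch.accelerator` API.

```python
import torch

def select_device(use_accel: bool) -> torch.device:
    # Mirrors the PR: use an accelerator only when requested AND available.
    if use_accel and hasattr(torch, "accelerator") and torch.accelerator.is_available():
        # e.g. device('cuda'), device('mps'), device('xpu'), ...
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

device = select_device(use_accel=False)
print(f"Using device: {device}")  # -> Using device: cpu
```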