Skip to content

Commit 12dc18e

Browse files
authored
Use torch.accelerator API in Siamese Network example (#1337)
1 parent b7aebb5 commit 12dc18e

File tree

3 files changed

+45
-14
lines changed

3 files changed

+45
-14
lines changed

siamese_network/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,37 @@
11
# Siamese Network Example
2+
Siamese network for image similarity estimation.
3+
The network is composed of two identical networks, one for each input.
4+
The output of each network is concatenated and passed to a linear layer.
5+
The output of the linear layer passed through a sigmoid function.
6+
[FaceNet](https://arxiv.org/pdf/1503.03832.pdf) is a variant of the Siamese network.
7+
This implementation varies from FaceNet as we use the `ResNet-18` model from
8+
[Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) as our feature extractor.
9+
In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
210

311
```bash
412
pip install -r requirements.txt
513
python main.py
614
# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2
715
```
16+
Optionally, you can add the following arguments to customize your execution.
17+
18+
```bash
19+
--batch-size input batch size for training (default: 64)
20+
--test-batch-size input batch size for testing (default: 1000)
21+
--epochs number of epochs to train (default: 14)
22+
--lr learning rate (default: 1.0)
23+
--gamma learning rate step gamma (default: 0.7)
24+
--accel use accelerator
25+
--dry-run quickly check a single pass
26+
--seed random seed (default: 1)
27+
--log-interval how many batches to wait before logging training status
28+
--save-model Saving the current Model
29+
```
30+
31+
To execute in an GPU, add the --accel argument to the command. For example:
32+
33+
```bash
34+
python main.py --accel
35+
```
36+
37+
This command will execute the example on the detected GPU.

siamese_network/main.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def group_examples(self):
105105
"""
106106

107107
# get the targets from MNIST dataset
108-
np_arr = np.array(self.dataset.targets.clone())
108+
np_arr = np.array(self.dataset.targets.clone(), dtype=None, copy=None)
109109

110110
# group examples based on class
111111
self.grouped_examples = {}
@@ -247,10 +247,8 @@ def main():
247247
help='learning rate (default: 1.0)')
248248
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
249249
help='Learning rate step gamma (default: 0.7)')
250-
parser.add_argument('--no-cuda', action='store_true', default=False,
251-
help='disables CUDA training')
252-
parser.add_argument('--no-mps', action='store_true', default=False,
253-
help='disables macOS GPU training')
250+
parser.add_argument('--accel', action='store_true',
251+
help='use accelerator')
254252
parser.add_argument('--dry-run', action='store_true', default=False,
255253
help='quickly check a single pass')
256254
parser.add_argument('--seed', type=int, default=1, metavar='S',
@@ -260,22 +258,25 @@ def main():
260258
parser.add_argument('--save-model', action='store_true', default=False,
261259
help='For Saving the current Model')
262260
args = parser.parse_args()
263-
264-
use_cuda = not args.no_cuda and torch.cuda.is_available()
265-
use_mps = not args.no_mps and torch.backends.mps.is_available()
266261

267262
torch.manual_seed(args.seed)
268263

269-
if use_cuda:
270-
device = torch.device("cuda")
271-
elif use_mps:
272-
device = torch.device("mps")
264+
if args.accel and not torch.accelerator.is_available():
265+
print("ERROR: accelerator is not available, try running on CPU")
266+
sys.exit(1)
267+
if not args.accel and torch.accelerator.is_available():
268+
print("WARNING: accelerator is available, run with --accel to enable it")
269+
270+
if args.accel:
271+
device = torch.accelerator.current_accelerator()
273272
else:
274273
device = torch.device("cpu")
274+
275+
print(f"Using device: {device}")
275276

276277
train_kwargs = {'batch_size': args.batch_size}
277278
test_kwargs = {'batch_size': args.test_batch_size}
278-
if use_cuda:
279+
if device=="cuda":
279280
cuda_kwargs = {'num_workers': 1,
280281
'pin_memory': True,
281282
'shuffle': True}

siamese_network/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch
2-
torchvision==0.20.0
2+
torchvision

0 commit comments

Comments
 (0)