Skip to content

Commit 0435cb1

Browse files
authored
Merge pull request #28 from Maratyszcza/master
Auto-detect architectures in ImageNet example
2 parents 66238bc + 9a92d76 commit 0435cb1

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

imagenet/main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,18 @@
1414
import torchvision.models as models
1515

1616

17+
model_names = sorted(name for name in models.__dict__
18+
if name.islower() and not name.startswith("__"))
19+
20+
1721
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
1822
parser.add_argument('data', metavar='DIR',
1923
help='path to dataset')
2024
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
21-
help='model architecture: resnet18 | resnet34 | ... '
22-
'(default: resnet18)')
25+
choices=model_names,
26+
help='model architecture: ' +
27+
' | '.join(model_names) +
28+
' (default: resnet18)')
2329
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
2430
help='number of data loading workers (default: 4)')
2531
parser.add_argument('--epochs', default=90, type=int, metavar='N',
@@ -50,9 +56,6 @@ def main():
5056
global args, best_prec1
5157
args = parser.parse_args()
5258

53-
if args.arch not in models.__dict__:
54-
parser.error('invalid architecture: {}'.format(args.arch))
55-
5659
# create model
5760
if args.pretrained:
5861
print("=> using pre-trained model '{}'".format(args.arch))

0 commit comments

Comments
 (0)