Reviewer note (text before the "diff --git" line is ignored by patch tools):
line structure was rebuilt from a whitespace-collapsed copy, so indentation and
blank-line placement are best-effort -- in particular the position of the added
save_model() call in the last hunk should be confirmed. Added lines were also
revised, so the post-image blob id on the "index" line is stale.
diff --git a/mnist/main.py b/mnist/main.py
index e3b7fc0beb..2197bfcf41 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -1,11 +1,29 @@
 from __future__ import print_function
 import argparse
+import errno
+import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 
+
+def makedir_exist_ok(dirpath):
+    """
+    Python2 support for os.makedirs(.., exist_ok=True)
+
+    Re-raises the OSError when dirpath exists but is not a directory.
+    """
+    try:
+        os.makedirs(dirpath)
+    except OSError as e:
+        if e.errno == errno.EEXIST and os.path.isdir(dirpath):
+            pass
+        else:
+            raise
+
+
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
@@ -55,6 +73,14 @@ def test(args, model, device, test_loader):
         test_loss, correct, len(test_loader.dataset),
         100. * correct / len(test_loader.dataset)))
 
+
+def save_model(args, model):
+    """Save the model's state_dict to 'models.pt' under args.result_dir."""
+    makedir_exist_ok(args.result_dir)
+    path = os.path.join(args.result_dir, 'models.pt')
+    torch.save(model.state_dict(), path)
+
+
 def main():
     # Training settings
     parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
@@ -74,6 +100,10 @@ def main():
                         help='random seed (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                         help='how many batches to wait before logging training status')
+    parser.add_argument('--data-dir', type=str, default='../data',
+                        help='location holding the training and test data.')
+    parser.add_argument('--result-dir', type=str, default='../data',
+                        help='location to save training results.')
     args = parser.parse_args()
     use_cuda = not args.no_cuda and torch.cuda.is_available()
 
@@ -83,14 +113,14 @@ def main():
 
     kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
     train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=True, download=True,
+        datasets.MNIST(args.data_dir, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
         batch_size=args.batch_size, shuffle=True, **kwargs)
     test_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=False, transform=transforms.Compose([
+        datasets.MNIST(args.data_dir, train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
@@ -104,6 +134,7 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(args, model, device, test_loader)
 
+    save_model(args, model)
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()