Skip to content

Commit d8dd8ca

Browse files
surgan12soumith
authored andcommitted
Mnist update (#469)
* mnist added dcgan * mnist added * mnist improved * Update main.py
1 parent 35a9d84 commit d8dd8ca

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

mnist/main.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,25 @@
66
import torch.optim as optim
77
from torchvision import datasets, transforms
88

9+
910
class Net(nn.Module):
1011
def __init__(self):
1112
super(Net, self).__init__()
12-
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
13-
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
14-
self.conv2_drop = nn.Dropout2d()
15-
self.fc1 = nn.Linear(320, 50)
16-
self.fc2 = nn.Linear(50, 10)
13+
self.conv1 = nn.Conv2d(1, 20, 5, 1)
14+
self.conv2 = nn.Conv2d(20, 50, 5, 1)
15+
self.fc1 = nn.Linear(4*4*50, 500)
16+
self.fc2 = nn.Linear(500, 10)
1717

1818
def forward(self, x):
19-
x = F.relu(F.max_pool2d(self.conv1(x), 2))
20-
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
21-
x = x.view(-1, 320)
19+
x = F.relu(self.conv1(x))
20+
x = F.max_pool2d(x, 2, 2)
21+
x = F.relu(self.conv2(x))
22+
x = F.max_pool2d(x, 2, 2)
23+
x = x.view(-1, 4*4*50)
2224
x = F.relu(self.fc1(x))
23-
x = F.dropout(x, training=self.training)
2425
x = self.fc2(x)
2526
return F.log_softmax(x, dim=1)
26-
27+
2728
def train(args, model, device, train_loader, optimizer, epoch):
2829
model.train()
2930
for batch_idx, (data, target) in enumerate(train_loader):
@@ -51,6 +52,7 @@ def test(args, model, device, test_loader):
5152
correct += pred.eq(target.view_as(pred)).sum().item()
5253

5354
test_loss /= len(test_loader.dataset)
55+
5456
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
5557
test_loss, correct, len(test_loader.dataset),
5658
100. * correct / len(test_loader.dataset)))
@@ -74,6 +76,9 @@ def main():
7476
help='random seed (default: 1)')
7577
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
7678
help='how many batches to wait before logging training status')
79+
80+
parser.add_argument('--save-model', action='store_true', default=False,
81+
help='For Saving the current Model')
7782
args = parser.parse_args()
7883
use_cuda = not args.no_cuda and torch.cuda.is_available()
7984

@@ -104,6 +109,8 @@ def main():
104109
train(args, model, device, train_loader, optimizer, epoch)
105110
test(args, model, device, test_loader)
106111

107-
112+
if (args.save_model):
113+
torch.save(model.state_dict(),"mnist_cnn.pt")
114+
108115
if __name__ == '__main__':
109-
main()
116+
main()

0 commit comments

Comments
 (0)