6
6
import torch .optim as optim
7
7
from torchvision import datasets , transforms
8
8
9
+
9
10
class Net (nn .Module ):
10
11
def __init__ (self ):
11
12
super (Net , self ).__init__ ()
12
- self .conv1 = nn .Conv2d (1 , 10 , kernel_size = 5 )
13
- self .conv2 = nn .Conv2d (10 , 20 , kernel_size = 5 )
14
- self .conv2_drop = nn .Dropout2d ()
15
- self .fc1 = nn .Linear (320 , 50 )
16
- self .fc2 = nn .Linear (50 , 10 )
13
+ self .conv1 = nn .Conv2d (1 , 20 , 5 , 1 )
14
+ self .conv2 = nn .Conv2d (20 , 50 , 5 , 1 )
15
+ self .fc1 = nn .Linear (4 * 4 * 50 , 500 )
16
+ self .fc2 = nn .Linear (500 , 10 )
17
17
18
18
def forward (self , x ):
19
- x = F .relu (F .max_pool2d (self .conv1 (x ), 2 ))
20
- x = F .relu (F .max_pool2d (self .conv2_drop (self .conv2 (x )), 2 ))
21
- x = x .view (- 1 , 320 )
19
+ x = F .relu (self .conv1 (x ))
20
+ x = F .max_pool2d (x , 2 , 2 )
21
+ x = F .relu (self .conv2 (x ))
22
+ x = F .max_pool2d (x , 2 , 2 )
23
+ x = x .view (- 1 , 4 * 4 * 50 )
22
24
x = F .relu (self .fc1 (x ))
23
- x = F .dropout (x , training = self .training )
24
25
x = self .fc2 (x )
25
26
return F .log_softmax (x , dim = 1 )
26
-
27
+
27
28
def train (args , model , device , train_loader , optimizer , epoch ):
28
29
model .train ()
29
30
for batch_idx , (data , target ) in enumerate (train_loader ):
@@ -51,6 +52,7 @@ def test(args, model, device, test_loader):
51
52
correct += pred .eq (target .view_as (pred )).sum ().item ()
52
53
53
54
test_loss /= len (test_loader .dataset )
55
+
54
56
print ('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
55
57
test_loss , correct , len (test_loader .dataset ),
56
58
100. * correct / len (test_loader .dataset )))
@@ -74,6 +76,9 @@ def main():
74
76
help = 'random seed (default: 1)' )
75
77
parser .add_argument ('--log-interval' , type = int , default = 10 , metavar = 'N' ,
76
78
help = 'how many batches to wait before logging training status' )
79
+
80
+ parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
81
+ help = 'For Saving the current Model' )
77
82
args = parser .parse_args ()
78
83
use_cuda = not args .no_cuda and torch .cuda .is_available ()
79
84
@@ -104,6 +109,8 @@ def main():
104
109
train (args , model , device , train_loader , optimizer , epoch )
105
110
test (args , model , device , test_loader )
106
111
107
-
112
+ if (args .save_model ):
113
+ torch .save (model .state_dict (),"mnist_cnn.pt" )
114
+
108
115
if __name__ == '__main__' :
109
- main ()
116
+ main ()
0 commit comments