Skip to content

Commit 99d1f15

Browse files
Add files via upload
1 parent 84cb947 commit 99d1f15

File tree

1 file changed

+306
-0
lines changed

1 file changed

+306
-0
lines changed

MNIST Classification Model.ipynb

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 2,
6+
"id": "ecdbd317",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch as tch\n",
11+
"import torchvision.datasets as dt\n",
12+
"import torchvision.transforms as trans\n",
13+
"import torch.nn as nn\n",
14+
"import matplotlib.pyplot as plt\n",
15+
"from time import time"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": 3,
21+
"id": "6c333d1a",
22+
"metadata": {},
23+
"outputs": [
24+
{
25+
"name": "stdout",
26+
"output_type": "stream",
27+
"text": [
28+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
29+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
30+
]
31+
},
32+
{
33+
"data": {
34+
"application/vnd.jupyter.widget-view+json": {
35+
"model_id": "01e4f3f628994928bfa4a950fe0a3e33",
36+
"version_major": 2,
37+
"version_minor": 0
38+
},
39+
"text/plain": [
40+
" 0%| | 0/9912422 [00:00<?, ?it/s]"
41+
]
42+
},
43+
"metadata": {},
44+
"output_type": "display_data"
45+
},
46+
{
47+
"name": "stdout",
48+
"output_type": "stream",
49+
"text": [
50+
"Extracting ./datasets\\MNIST\\raw\\train-images-idx3-ubyte.gz to ./datasets\\MNIST\\raw\n",
51+
"\n",
52+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
53+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
54+
]
55+
},
56+
{
57+
"data": {
58+
"application/vnd.jupyter.widget-view+json": {
59+
"model_id": "1de9e88ce1c94b34b1cc1c8b5b9c70a8",
60+
"version_major": 2,
61+
"version_minor": 0
62+
},
63+
"text/plain": [
64+
" 0%| | 0/28881 [00:00<?, ?it/s]"
65+
]
66+
},
67+
"metadata": {},
68+
"output_type": "display_data"
69+
},
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"Extracting ./datasets\\MNIST\\raw\\train-labels-idx1-ubyte.gz to ./datasets\\MNIST\\raw\n",
75+
"\n",
76+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
77+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
78+
]
79+
},
80+
{
81+
"data": {
82+
"application/vnd.jupyter.widget-view+json": {
83+
"model_id": "ebf88e8850e7413fad33668a8d0bf344",
84+
"version_major": 2,
85+
"version_minor": 0
86+
},
87+
"text/plain": [
88+
" 0%| | 0/1648877 [00:00<?, ?it/s]"
89+
]
90+
},
91+
"metadata": {},
92+
"output_type": "display_data"
93+
},
94+
{
95+
"name": "stdout",
96+
"output_type": "stream",
97+
"text": [
98+
"Extracting ./datasets\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./datasets\\MNIST\\raw\n",
99+
"\n",
100+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
101+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
102+
]
103+
},
104+
{
105+
"data": {
106+
"application/vnd.jupyter.widget-view+json": {
107+
"model_id": "2693df027cae4fd3981044ca79135e89",
108+
"version_major": 2,
109+
"version_minor": 0
110+
},
111+
"text/plain": [
112+
" 0%| | 0/4542 [00:00<?, ?it/s]"
113+
]
114+
},
115+
"metadata": {},
116+
"output_type": "display_data"
117+
},
118+
{
119+
"name": "stdout",
120+
"output_type": "stream",
121+
"text": [
122+
"Extracting ./datasets\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to ./datasets\\MNIST\\raw\n",
123+
"\n",
124+
"No. of Training examples: 60000\n",
125+
"No. of Test examples: 10000\n"
126+
]
127+
}
128+
],
129+
"source": [
130+
"train = dt.MNIST(root=\"./datasets\", train=True, transform=trans.ToTensor(), download=True)\n",
131+
"test = dt.MNIST(root=\"./datasets\", train=False, transform=trans.ToTensor(), download=True)\n",
132+
"print(\"No. of Training examples: \",len(train))\n",
133+
"print(\"No. of Test examples: \",len(test))"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 4,
139+
"id": "afa4ff12",
140+
"metadata": {},
141+
"outputs": [],
142+
"source": [
143+
"train_batch = tch.utils.data.DataLoader(train, batch_size=30, shuffle=True)"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": 5,
149+
"id": "022d262f",
150+
"metadata": {},
151+
"outputs": [],
152+
"source": [
153+
"input = 784\n",
154+
"hidden = 490\n",
155+
"output = 10"
156+
]
157+
},
158+
{
159+
"cell_type": "code",
160+
"execution_count": 6,
161+
"id": "ea590922",
162+
"metadata": {},
163+
"outputs": [],
164+
"source": [
165+
"model = nn.Sequential(nn.Linear(input, hidden),\n",
166+
" nn.LeakyReLU(),\n",
167+
" nn.Linear(hidden, output),\n",
168+
" nn.LogSoftmax(dim=1))"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": 7,
174+
"id": "fa9a5b1f",
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178+
"lossfn = nn.NLLLoss()\n",
179+
"images, labels = next(iter(train_batch))\n",
180+
"images = images.view(images.shape[0], -1)\n",
181+
"\n",
182+
"logps = model(images)\n",
183+
"loss = lossfn(logps, labels)\n",
184+
"loss.backward()"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": 8,
190+
"id": "e1fb5d19",
191+
"metadata": {},
192+
"outputs": [
193+
{
194+
"name": "stdout",
195+
"output_type": "stream",
196+
"text": [
197+
"Epoch Number : 0 = Loss : 0.5149400651156902\n",
198+
"Epoch Number : 1 = Loss : 0.261456840605475\n",
199+
"Epoch Number : 2 = Loss : 0.20588867816049605\n",
200+
"Epoch Number : 3 = Loss : 0.16964873825758695\n",
201+
"Epoch Number : 4 = Loss : 0.1434834775705822\n",
202+
"Epoch Number : 5 = Loss : 0.12429279719106853\n",
203+
"Epoch Number : 6 = Loss : 0.10908355080941692\n",
204+
"Epoch Number : 7 = Loss : 0.09697999537643046\n",
205+
"Epoch Number : 8 = Loss : 0.08723836344201118\n",
206+
"Epoch Number : 9 = Loss : 0.07917423069826328\n",
207+
"Epoch Number : 10 = Loss : 0.07214489371958188\n",
208+
"Epoch Number : 11 = Loss : 0.06623679360805546\n",
209+
"Epoch Number : 12 = Loss : 0.060786034525139254\n",
210+
"Epoch Number : 13 = Loss : 0.05600704051565845\n",
211+
"Epoch Number : 14 = Loss : 0.05210975646332372\n",
212+
"Epoch Number : 15 = Loss : 0.04836869774857769\n",
213+
"Epoch Number : 16 = Loss : 0.045035426611895676\n",
214+
"Epoch Number : 17 = Loss : 0.04181636443955358\n",
215+
"\n",
216+
"Training Time (in minutes) : 3.022224660714467\n"
217+
]
218+
}
219+
],
220+
"source": [
221+
"optimize = tch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)\n",
222+
"time_start = time()\n",
223+
"epochs = 18\n",
224+
"for num in range(epochs):\n",
225+
" run=0\n",
226+
" for images, labels in train_batch:\n",
227+
" images = images.view(images.shape[0], -1)\n",
228+
" optimize.zero_grad()\n",
229+
" output = model(images)\n",
230+
" loss = lossfn(output, labels)\n",
231+
" loss.backward()\n",
232+
" optimize.step()\n",
233+
" run += loss.item()\n",
234+
" else:\n",
235+
" print(\"Epoch Number : {} = Loss : {}\".format(num, run/len(train_batch)))\n",
236+
"Elapsed=(time()-time_start)/60\n",
237+
"print(\"\\nTraining Time (in minutes) : \",Elapsed)"
238+
]
239+
},
240+
{
241+
"cell_type": "code",
242+
"execution_count": 9,
243+
"id": "03c79aae",
244+
"metadata": {},
245+
"outputs": [
246+
{
247+
"name": "stdout",
248+
"output_type": "stream",
249+
"text": [
250+
"Number Of Images Tested : 10000\n",
251+
"Model Accuracy : 0.9777\n"
252+
]
253+
}
254+
],
255+
"source": [
256+
"correct=0\n",
257+
"all = 0\n",
258+
"for images,labels in test:\n",
259+
" img = images.view(1, 784)\n",
260+
" with tch.no_grad():\n",
261+
" logps = model(img) \n",
262+
" ps = tch.exp(logps)\n",
263+
" probab = list(ps.numpy()[0])\n",
264+
" prediction = probab.index(max(probab))\n",
265+
" truth = labels\n",
266+
" if(truth == prediction):\n",
267+
" correct += 1\n",
268+
" all += 1\n",
269+
"\n",
270+
"print(\"Number Of Images Tested : \", all)\n",
271+
"print(\"Model Accuracy : \", (correct/all))"
272+
]
273+
},
274+
{
275+
"cell_type": "code",
276+
"execution_count": 10,
277+
"id": "3569476c",
278+
"metadata": {},
279+
"outputs": [],
280+
"source": [
281+
"tch.save(model, './mnist_model.pt')"
282+
]
283+
}
284+
],
285+
"metadata": {
286+
"kernelspec": {
287+
"display_name": "Python 3 (ipykernel)",
288+
"language": "python",
289+
"name": "python3"
290+
},
291+
"language_info": {
292+
"codemirror_mode": {
293+
"name": "ipython",
294+
"version": 3
295+
},
296+
"file_extension": ".py",
297+
"mimetype": "text/x-python",
298+
"name": "python",
299+
"nbconvert_exporter": "python",
300+
"pygments_lexer": "ipython3",
301+
"version": "3.9.13"
302+
}
303+
},
304+
"nbformat": 4,
305+
"nbformat_minor": 5
306+
}

0 commit comments

Comments
 (0)