-
Notifications
You must be signed in to change notification settings - Fork 9.7k
Mnist update #469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mnist update #469
Conversation
mnist/main.py
Outdated
x = self.fc2(x) | ||
return F.log_softmax(x, dim=1) | ||
|
||
|
||
def name(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function should be removed
mnist/main.py
Outdated
if (args.save_model): | ||
PATH="mnist_cnn.pt" | ||
torch.save(model.state_dict(), PATH) | ||
print("model saved as mnist_cnn.pt in the current working directory\n") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print is unnecessary
mnist/main.py
Outdated
train(args, model, device, train_loader, optimizer, epoch) | ||
test(args, model, device, test_loader) | ||
|
||
if (args.save_model): | ||
PATH="mnist_cnn.pt" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
additional PATH is unncessary, do torch.save(model.state_dict(), "mnist_cnn.pt")
mnist/main.py
Outdated
@@ -100,10 +107,14 @@ def main(): | |||
model = Net().to(device) | |||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) | |||
|
|||
for epoch in range(1, args.epochs + 1): | |||
for epoch in range(1, args.epochs+1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
args.epochs + 1
is cleaner. Revert this change
@soumith i have made the changes as asked by you. |
thank you @surgan12 ! |
@soumith my pleasure always ! |
* mnist added dcgan * mnist added * mnist improved * Update main.py
Changing the architecture of the Mnist , simplifying it and increasing the accuracy.
Also added the functionalty of saving the model as an argument.