From 522c793b88b67b92f0c8ed905a055c62359051d1 Mon Sep 17 00:00:00 2001
From: ver2king <31938494+ver2king@users.noreply.github.com>
Date: Sun, 12 Nov 2023 19:19:29 -0600
Subject: [PATCH] Changes to PyTorch tutorial for Spatial Transformer Networks

1. Add a shareable link for the distorted MNIST dataset (image size 60*60)
2. Add two additional Spatial Transformer Networks (STNs) with different
   padding modes in the `torch.nn.functional.grid_sample` call
3. Compare the performance of all STNs (for both the original and the
   distorted MNIST datasets)
---
 .../spatial_transformer_tutorial.py           | 207 ++++++++++++++++--
 1 file changed, 191 insertions(+), 16 deletions(-)

diff --git a/intermediate_source/spatial_transformer_tutorial.py b/intermediate_source/spatial_transformer_tutorial.py
index 49b6b0f0a2b..ebd82205ca1 100644
--- a/intermediate_source/spatial_transformer_tutorial.py
+++ b/intermediate_source/spatial_transformer_tutorial.py
@@ -23,6 +23,13 @@
 One of the best things about STN is the ability to simply
 plug it into any existing CNN with very little modification.
+
+Update for this tutorial:
+
+- Add a distorted MNIST dataset (image size 60*60) to stress-test the
+  original approach, which uses ``grid_sample`` with ``padding_mode="zeros"``.
+- Add a new Spatial Transformer Network compatible with the distorted
+  MNIST dataset.
+- Quantify the difference between the ``padding_mode`` options of
+  ``grid_sample`` (i.e., "zeros" vs. "border").
 """
 # License: BSD
 # Author: Ghassen Hamrouni

@@ -31,12 +38,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
 import torchvision
 from torchvision import datasets, transforms
 import matplotlib.pyplot as plt
 import numpy as np

-plt.ion()   # interactive mode
+plt.ion()  # interactive mode

 ######################################################################
 # Loading the data
 # ----------------

@@ -46,7 +54,10 @@
 # standard convolutional network augmented with a spatial transformer
 # network.

+from google_drive_downloader import GoogleDriveDownloader as GDD
 from six.moves import urllib
+
 opener = urllib.request.build_opener()
 opener.addheaders = [('User-agent', 'Mozilla/5.0')]
 urllib.request.install_opener(opener)

@@ -54,18 +65,93 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Training dataset
-train_loader = torch.utils.data.DataLoader(
+normal_train_loader = torch.utils.data.DataLoader(
     datasets.MNIST(root='.', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
-                   ])), batch_size=64, shuffle=True, num_workers=4)
+                   ])), batch_size=64, shuffle=True, num_workers=0)
 # Test dataset
-test_loader = torch.utils.data.DataLoader(
+normal_test_loader = torch.utils.data.DataLoader(
     datasets.MNIST(root='.', train=False, transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
-    ])), batch_size=64, shuffle=True, num_workers=4)
+    ])), batch_size=64, shuffle=True, num_workers=0)
+
+
+######################################################################
+# Loading the distorted data
+# --------------------------
+#
+# Update: to probe the Spatial Transformer Network further, as the updated
+# aims indicate, we also experiment with a distorted MNIST dataset.
+# In the distorted MNIST dataset, for each image:
+#
+# - The original digit is placed at a random position on a black 60*60 canvas.
+# - Noise (i.e., patches sampled from images of other digits) is placed at
+#   random positions on the same 60*60 canvas. One plausible way to generate
+#   such a sample is sketched below.
+#
+# The distorted MNIST dataset with image size 60*60 is taken from:
+# https://github.com/theRealSuperMario/pytorch_stn/blob/master/data/mnist_cluttered_60.npz
+#
+# A preview of the distorted MNIST dataset is available at:
+# https://drive.google.com/file/d/1txYwNjgY5FxYIUuScE7AKgmeXA4MJB5R/view?usp=drive_link
+#
+# Credit for this distorted MNIST dataset goes to
+# **Author**: `Sandro Braun <https://github.com/theRealSuperMario/pytorch_stn>`_
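+
+# For illustration only, a minimal editorial sketch (not necessarily how the
+# linked ``.npz`` archive was built): paste a 28*28 digit at a random spot on
+# a 60*60 black canvas, then scatter small patches cropped from other digits.
+# The helper name ``make_cluttered`` and its arguments are hypothetical;
+# ``digit`` is a (28, 28) tensor and ``noise_sources`` a (N, 28, 28) tensor
+# of other digit images.
+def make_cluttered(digit, noise_sources, canvas_size=60, n_patches=4, patch_size=8):
+    canvas = torch.zeros(canvas_size, canvas_size)
+    # Paste the digit at a random location on the black canvas.
+    dy, dx = torch.randint(0, canvas_size - 28 + 1, (2,)).tolist()
+    canvas[dy:dy + 28, dx:dx + 28] = digit
+    # Scatter distractor patches cropped from other digit images.
+    for _ in range(n_patches):
+        src = noise_sources[torch.randint(0, len(noise_sources), (1,)).item()]
+        sy, sx = torch.randint(0, 28 - patch_size + 1, (2,)).tolist()
+        py, px = torch.randint(0, canvas_size - patch_size + 1, (2,)).tolist()
+        region = canvas[py:py + patch_size, px:px + patch_size]
+        canvas[py:py + patch_size, px:px + patch_size] = torch.maximum(
+            region, src[sy:sy + patch_size, sx:sx + patch_size])
+    return canvas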
+
+# Helper class to wrap the distorted dataset
+class DistortedDataSet(Dataset):
+    """
+    Generate a dataset composed of:
+    - the original inputs & outputs (as tensors), and
+    - an optional transform (e.g., from torchvision) applied to each input.
+    """
+
+    def __init__(self, inputs, outputs, transform=None):
+        super(DistortedDataSet, self).__init__()
+        self.inputs = inputs
+        self.outputs = outputs
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __getitem__(self, idx):
+        input_ = self.inputs[idx]
+        input_ = input_[None, :, :]  # add a channel dimension: (H, W) -> (1, H, W)
+        output_ = int(self.outputs[idx])
+        if self.transform:
+            input_ = self.transform(input_)
+        return input_, output_
+
+
+# Download the distorted MNIST dataset first
+distorted_file_id = '1txYwNjgY5FxYIUuScE7AKgmeXA4MJB5R'
+GDD.download_file_from_google_drive(file_id=distorted_file_id,
+                                    dest_path='./distorted_mnist_60.npz',
+                                    unzip=False)
+distorted_data = np.load('distorted_mnist_60.npz')
+
+# Training dataset (distorted)
+train_images = torch.tensor(distorted_data['X_train'], dtype=torch.float32)
+train_digits = torch.tensor(distorted_data['y_train'], dtype=torch.int64)
+train_set = DistortedDataSet(inputs=train_images, outputs=train_digits,
+                             transform=transforms.Compose([
+                                 transforms.Normalize((0.1307,), (0.3081,))]))
+distorted_train_loader = DataLoader(
+    dataset=train_set, batch_size=64, shuffle=True, num_workers=0)
+
+# Test dataset (distorted)
+test_images = torch.tensor(distorted_data['X_test'], dtype=torch.float32)
+test_digits = torch.tensor(distorted_data['y_test'], dtype=torch.int64)
+test_set = DistortedDataSet(inputs=test_images, outputs=test_digits,
+                            transform=transforms.Compose([
+                                transforms.Normalize((0.1307,), (0.3081,))]))
+distorted_test_loader = DataLoader(
+    dataset=test_set, batch_size=64, shuffle=False, num_workers=0)
+

 ######################################################################
 # Depicting spatial transformer networks
 # ---------------------------------------

@@ -88,6 +174,10 @@
 # We need the latest version of PyTorch that contains
 # affine_grid and grid_sample modules.
 #
+# Update: in line with the updated aims above:
+#
+# - A Spatial Transformer Network that accepts 60*60 images, named Net_60,
+#   is added.
+# - Net_60 takes either "zeros" or "border" as the padding_mode argument of
+#   grid_sample (a short sketch of the difference follows below).
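+
+# A tiny editorial sketch (not part of the original tutorial) of that
+# difference: a 1x1x2x2 image is sampled half a pixel to the right of its own
+# extent, so the rightmost output column falls out of bounds. With
+# padding_mode="zeros" that column is filled with 0; with "border" it
+# replicates the nearest edge pixels. The ``toy_*`` names are hypothetical.
+toy_input = torch.arange(4., device=device).view(1, 1, 2, 2)
+toy_theta = torch.tensor([[[1., 0., 0.5], [0., 1., 0.]]], device=device)
+toy_grid = F.affine_grid(toy_theta, toy_input.size())
+print(F.grid_sample(toy_input, toy_grid, padding_mode="zeros"))   # zeros leak in
+print(F.grid_sample(toy_input, toy_grid, padding_mode="border"))  # edges repeat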
 
 class Net(nn.Module):
@@ -146,7 +236,53 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 
-model = Net().to(device)
+class Net_60(nn.Module):
+    def __init__(self, padding_mode):
+        super(Net_60, self).__init__()
+        # Localization network for 60*60 inputs:
+        # 60 -> conv(7) -> 54 -> pool(2) -> 27 -> conv(5) -> 23 -> pool(2) -> 11,
+        # so the localization features flatten to 10 * 11 * 11.
+        self.localization = nn.Sequential(
+            nn.Conv2d(1, 8, kernel_size=7),
+            nn.MaxPool2d(2, stride=2),
+            nn.ReLU(True),
+            nn.Conv2d(8, 10, kernel_size=5),
+            nn.MaxPool2d(2, stride=2),
+            nn.ReLU(True)
+        )
+        # Classification branch for 60*60 inputs:
+        # 60 -> conv(5) -> 56 -> pool(2) -> 28 -> conv(5) -> 24 -> pool(2) -> 12,
+        # so the flattened feature size is 20 * 12 * 12 = 2880.
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(2880, 50)
+        self.fc2 = nn.Linear(50, 10)
+        # Regressor for the 3 * 2 affine matrix
+        self.fc_loc = nn.Sequential(
+            nn.Linear(10 * 11 * 11, 32),
+            nn.ReLU(True),
+            nn.Linear(32, 3 * 2)
+        )
+        # Initialize the weights/bias with the identity transformation
+        self.fc_loc[2].weight.data.zero_()
+        self.fc_loc[2].bias.data.copy_(
+            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
+        )
+        self.padding_mode = padding_mode
+
+    # Spatial transformer network forward function
+    def stn(self, x):
+        xs = self.localization(x)
+        xs = xs.view(-1, 10 * 11 * 11)
+        theta = self.fc_loc(xs)
+        theta = theta.view(-1, 2, 3)
+        grid = F.affine_grid(theta, x.size())
+        x = F.grid_sample(x, grid, padding_mode=self.padding_mode)
+        return x
+
+    def forward(self, x):
+        # Transform the input, then perform the usual forward pass
+        x = self.stn(x)
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 2880)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+model_28 = Net().to(device)
+model_60_padding_zeros = Net_60(padding_mode="zeros").to(device)
+model_60_padding_border = Net_60(padding_mode="border").to(device)
 
 ######################################################################
 # Training the model
 # ------------------
@@ -157,10 +293,12 @@ def forward(self, x):
 # the model is learning STN automatically in an end-to-end fashion.
 
 
-optimizer = optim.SGD(model.parameters(), lr=0.01)
+optimizer_28 = optim.SGD(model_28.parameters(), lr=0.01)
+optimizer_60_padding_zeros = optim.SGD(model_60_padding_zeros.parameters(), lr=0.01)
+optimizer_60_padding_border = optim.SGD(model_60_padding_border.parameters(), lr=0.01)
 
 
-def train(epoch):
+def train(model, train_loader, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
@@ -173,13 +311,15 @@ def train(epoch):
         if batch_idx % 500 == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
+
+
 #
 # A simple test procedure to measure the STN performance on MNIST.
 #
-def test():
+def test(model, test_loader):
     with torch.no_grad():
         model.eval()
         test_loss = 0
@@ -199,6 +339,7 @@ def test():
           .format(test_loss, correct, len(test_loader.dataset),
                   100. * correct / len(test_loader.dataset)))
 
+
 ######################################################################
 # Visualizing the STN results
 # ---------------------------
@@ -219,12 +360,15 @@ def convert_image_np(inp):
     inp = np.clip(inp, 0, 1)
     return inp
 
+
 # We want to visualize the output of the spatial transformer layer
 # after training: we visualize a batch of input images and the
 # corresponding transformed batch using STN.
+#
+# Update: in line with the updated aims above, this function is modified to
+# take any model and any torch DataLoader.

-
-def visualize_stn():
+def visualize_stn(model, test_loader):
     with torch.no_grad():
         # Get a batch of training data
         data = next(iter(test_loader))[0].to(device)
@@ -246,12 +390,43 @@ def visualize_stn():
         axarr[1].imshow(out_grid)
         axarr[1].set_title('Transformed Images')
 
+
+# Update: in line with the updated aims above, we now perform the following:
+# 1. Train, test and visualize model_28 on the original images (size 28*28)
+# 2. Train, test and visualize model_60_padding_zeros on the distorted images (size 60*60)
+# 3. Train, test and visualize model_60_padding_border on the distorted images (size 60*60)
+
+# The model for the original image size 28*28
+for epoch in range(1, 20 + 1):
+    train(model_28, normal_train_loader, optimizer_28, epoch)
+    test(model_28, normal_test_loader)
+
+# The model for the distorted image size 60*60, with padding_mode="zeros"
 for epoch in range(1, 20 + 1):
-    train(epoch)
-    test()
+    train(model_60_padding_zeros, distorted_train_loader, optimizer_60_padding_zeros, epoch)
+    test(model_60_padding_zeros, distorted_test_loader)
 
-# Visualize the STN transformation on some input batch
-visualize_stn()
+# The model for the distorted image size 60*60, with padding_mode="border"
+for epoch in range(1, 20 + 1):
+    train(model_60_padding_border, distorted_train_loader, optimizer_60_padding_border, epoch)
+    test(model_60_padding_border, distorted_test_loader)
+
+# Visualize the STN transformation on some input batches for model_28,
+# model_60_padding_zeros and model_60_padding_border, respectively
+visualize_stn(model_28, normal_test_loader)
+visualize_stn(model_60_padding_zeros, distorted_test_loader)
+visualize_stn(model_60_padding_border, distorted_test_loader)
 
 plt.ioff()
 plt.show()
+
+######################################################################
+# Interpreting the STN results
+# ----------------------------
+#
+# With the visualizations from the three Spatial Transformer Networks above,
+# the following qualitative points are worth checking (exact numbers depend
+# on the run):
+#
+# - On the original 28*28 MNIST images the learned transform is mild, so the
+#   choice of ``padding_mode`` has little visible effect.
+# - On the distorted 60*60 images the STN has to zoom in on a small digit, so
+#   many sampling locations fall outside the input; with
+#   ``padding_mode="zeros"`` those locations are filled with black.
+# - With ``padding_mode="border"`` out-of-bounds locations replicate the
+#   nearest edge pixels instead, which changes the artifacts at the crop
+#   boundary; comparing the test accuracies, as sketched below, quantifies
+#   the effect.
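+#
+# One possible way to make that comparison concrete (an editorial sketch, not
+# part of the original tutorial): a small helper that returns test accuracy,
+# so the three models can be printed side by side. The name ``accuracy`` is
+# hypothetical.
+def accuracy(model, loader):
+    model.eval()
+    correct = 0
+    with torch.no_grad():
+        for data, target in loader:
+            data, target = data.to(device), target.to(device)
+            # Count predictions that match the label.
+            pred = model(data).max(1, keepdim=True)[1]
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    return correct / len(loader.dataset)
+
+
+print('28*28 baseline accuracy:', accuracy(model_28, normal_test_loader))
+print('60*60 "zeros" accuracy:', accuracy(model_60_padding_zeros, distorted_test_loader))
+print('60*60 "border" accuracy:', accuracy(model_60_padding_border, distorted_test_loader))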