diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index e6bca8b0dd4..e39e94a0319 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -4,15 +4,17 @@
 =========================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
 
+**Updated by**: `Adam Dziedzic <https://github.com/adam-dziedzic>`_
+
 In this tutorial, we shall go through two tasks:
 
 1. Create a neural network layer with no parameters.
 
-    -  This calls into **numpy** as part of it’s implementation
+    -  This calls into **numpy** as part of its implementation
 
 2. Create a neural network layer that has learnable weights
 
-    -  This calls into **SciPy** as part of it’s implementation
+    -  This calls into **SciPy** as part of its implementation
 """
 
 import torch
@@ -64,21 +66,18 @@ def incorrect_fft(input):
 # Parametrized example
 # --------------------
 #
-# This implements a layer with learnable weights.
-#
-# It implements the Cross-correlation with a learnable kernel.
-#
-# In deep learning literature, it’s confusingly referred to as
-# Convolution.
-#
-# The backward computes the gradients wrt the input and gradients wrt the
-# filter.
+# In deep learning literature, this layer is confusingly referred
+# to as convolution while the actual operation is cross-correlation
+# (the only difference is that the filter is flipped for convolution,
+# which is not the case for cross-correlation).
 #
-# **Implementation:**
+# Implementation of a layer with learnable weights, where the cross-correlation
+# has a filter (kernel) that represents the weights.
 #
-# *Please Note that the implementation serves as an illustration, and we
-# did not verify it’s correctness*
+# The backward pass computes the gradient wrt the input and the gradient wrt the filter.
 
+from numpy import flip
+import numpy as np
 from scipy.signal import convolve2d, correlate2d
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
@@ -86,38 +85,55 @@ def incorrect_fft(input):
 class ScipyConv2dFunction(Function):
     @staticmethod
-    def forward(ctx, input, filter):
-        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
-        result = correlate2d(input.numpy(), filter.detach().numpy(), mode='valid')
-        ctx.save_for_backward(input, filter)
-        return input.new(result)
+    def forward(ctx, input, filter, bias):
+        # detach so we can cast to NumPy
+        input, filter, bias = input.detach(), filter.detach(), bias.detach()
+        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
+        result += bias.numpy()
+        ctx.save_for_backward(input, filter, bias)
+        return torch.from_numpy(result)
 
     @staticmethod
     def backward(ctx, grad_output):
         grad_output = grad_output.detach()
-        input, filter = ctx.saved_tensors
-        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
-        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
-
-        return grad_output.new_tensor(grad_input), grad_output.new_tensor(grad_filter)
+        input, filter, bias = ctx.saved_tensors
+        grad_output = grad_output.numpy()
+        grad_bias = np.sum(grad_output, keepdims=True)
+        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
+        # the previous line can be expressed equivalently as:
+        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
+        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
+        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
 
 
 class ScipyConv2d(Module):
-
-    def __init__(self, kh, kw):
+    def __init__(self, filter_width, filter_height):
         super(ScipyConv2d, self).__init__()
-        self.filter = Parameter(torch.randn(kh, kw))
+        self.filter = Parameter(torch.randn(filter_width, filter_height))
+        self.bias = Parameter(torch.randn(1, 1))
 
     def forward(self, input):
-        return ScipyConv2dFunction.apply(input, self.filter)
+        return ScipyConv2dFunction.apply(input, self.filter, self.bias)
+
 
 
 ###############################################################
 # **Example usage:**
 
 module = ScipyConv2d(3, 3)
-print(list(module.parameters()))
+print("Filter and bias: ", list(module.parameters()))
 input = torch.randn(10, 10, requires_grad=True)
 output = module(input)
-print(output)
+print("Output from the convolution: ", output)
 output.backward(torch.randn(8, 8))
-print(input.grad)
+print("Gradient for the input map: ", input.grad)
+
+###############################################################
+# **Check the gradients:**
+
+from torch.autograd.gradcheck import gradcheck
+
+moduleConv = ScipyConv2d(3, 3)
+
+input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
+test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
+print("Are the gradients correct: ", test)
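
The backward pass above notes that the full convolution of the incoming gradient with the filter can be written equivalently as a cross-correlation with the filter flipped along both axes, which is the same flipping relationship the updated comments describe. A minimal standalone sketch of that equivalence (the array names here are illustrative and not part of the patch)::

    import numpy as np
    from numpy import flip
    from scipy.signal import convolve2d, correlate2d

    x = np.random.randn(5, 5)  # stand-in for an input or gradient map
    k = np.random.randn(3, 3)  # stand-in for the filter

    # Convolution flips the kernel; cross-correlation does not.
    conv = convolve2d(x, k, mode='full')
    corr = correlate2d(x, flip(flip(k, axis=0), axis=1), mode='full')
    print(np.allclose(conv, corr))  # expected: True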
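Because the layer computes a 'valid' cross-correlation plus a scalar bias, its forward output should also agree with torch.nn.functional.conv2d, which likewise performs cross-correlation. A quick cross-check along these lines, assuming the ScipyConv2d module from the patch above is in scope, might look like this (the reshapes only add the batch and channel dimensions conv2d expects)::

    import torch
    import torch.nn.functional as F

    module = ScipyConv2d(3, 3)   # module defined in the patched tutorial
    x = torch.randn(10, 10)

    out_scipy = module(x)
    out_torch = F.conv2d(x.view(1, 1, 10, 10),
                         module.filter.view(1, 1, 3, 3),
                         bias=module.bias.view(1)).view(8, 8)

    # cast to a common dtype before comparing
    print(torch.allclose(out_scipy.double(), out_torch.double(), atol=1e-5))  # expected: True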