From 787b1fe746ab1792f2482aff47bad4791be59adb Mon Sep 17 00:00:00 2001
From: Adam Dziedzic
Date: Fri, 29 Jun 2018 18:08:00 -0500
Subject: [PATCH 1/6] corrected the example with cross-correlation with scipy
 - the example is fully correct with appropriate gradients

---
 advanced_source/numpy_extensions_tutorial.py | 53 +++++++++++++-------
 1 file changed, 36 insertions(+), 17 deletions(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index e6bca8b0dd4..87aeefcf601 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -4,6 +4,8 @@
 =========================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
 
+**Updated by**: `Adam Dziedzic` [https://github.com/adam-dziedzic](https://github.com/adam-dziedzic)
+
 In this tutorial, we shall go through two tasks:
 
 1. Create a neural network layer with no parameters.
@@ -79,45 +81,62 @@
 # *Please Note that the implementation serves as an illustration, and we
 # did not verify it’s correctness*
 
-from scipy.signal import convolve2d, correlate2d
+from numpy import flip
+import numpy as np
+from scipy.signal import correlate2d
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
-
 class ScipyConv2dFunction(Function):
     @staticmethod
-    def forward(ctx, input, filter):
-        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
-        result = correlate2d(input.numpy(), filter.detach().numpy(), mode='valid')
-        ctx.save_for_backward(input, filter)
-        return input.new(result)
+    def forward(ctx, input, filter, bias):
+        # detach so we can cast to NumPy
+        input, filter, bias = input.detach(), filter.detach(), bias.detach()
+        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
+        result += bias.numpy()
+        ctx.save_for_backward(input, filter, bias)
+        return torch.from_numpy(result)
 
     @staticmethod
     def backward(ctx, grad_output):
         grad_output = grad_output.detach()
-        input, filter = ctx.saved_tensors
-        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
-        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
-
-        return grad_output.new_tensor(grad_input), grad_output.new_tensor(grad_filter)
+        input, filter, bias = ctx.saved_tensors
+        grad_output = grad_output.numpy()
+        grad_bias = np.sum(grad_output, keepdims=True)
+        grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
+        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
+        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
 
 
 class ScipyConv2d(Module):
-
     def __init__(self, kh, kw):
         super(ScipyConv2d, self).__init__()
         self.filter = Parameter(torch.randn(kh, kw))
+        self.bias = Parameter(torch.randn(1, 1))
 
     def forward(self, input):
-        return ScipyConv2dFunction.apply(input, self.filter)
+        return ScipyConv2dFunction.apply(input, self.filter, self.bias)
+
 
 ###############################################################
 # **Example usage:**
 
 module = ScipyConv2d(3, 3)
-print(list(module.parameters()))
+print("Filter and bias: ", list(module.parameters()))
 input = torch.randn(10, 10, requires_grad=True)
 output = module(input)
-print(output)
+print("Output from the convolution: ", output)
 output.backward(torch.randn(8, 8))
-print(input.grad)
+print("Gradient for the input map: ", input.grad)
+
+###############################################################
+# **Check the gradients:**
+
+from torch.autograd import gradcheck
+
+moduleConv = ScipyConv2d(3, 3)
+
+input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
+# print("input: ", input)
+test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
+print("Are the gradients correct: ", test)
\ No newline at end of file
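A note on the shapes in the corrected backward pass: for a (10, 10) input and a (3, 3) filter, 'valid' cross-correlation yields an (8, 8) output, which is why output.backward is fed torch.randn(8, 8), and correlating the output gradient with the doubly flipped filter in 'full' mode restores the (10, 10) input shape. A minimal standalone sketch (illustrative only, not part of the patch):

import numpy as np
from scipy.signal import correlate2d

x = np.random.randn(10, 10)
w = np.random.randn(3, 3)

out = correlate2d(x, w, mode='valid')
print(out.shape)  # (8, 8): 'valid' shrinks each side by (kernel size - 1)

grad_out = np.random.randn(*out.shape)
# same recipe as grad_input in the patch: full-mode correlation with the flipped filter
grad_in = correlate2d(grad_out, np.flip(np.flip(w, 0), 1), mode='full')
print(grad_in.shape)  # (10, 10): 'full' grows back to the input shape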
From 43fe93fba9e3209c2287616d700dca8fb22ce1ea Mon Sep 17 00:00:00 2001
From: Adam Dziedzic
Date: Sat, 30 Jun 2018 05:48:07 -0500
Subject: [PATCH 2/6] improve the code and comments; corrected the example
 with cross-correlation with scipy - the example is fully correct with
 appropriate gradients

---
 advanced_source/numpy_extensions_tutorial.py | 30 +++++++++-----------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index 87aeefcf601..5508587907e 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -10,16 +10,18 @@
 
 1. Create a neural network layer with no parameters.
 
-  - This calls into **numpy** as part of it’s implementation
+  - This calls into **numpy** as part of its implementation
 
 2. Create a neural network layer that has learnable weights
 
-  - This calls into **SciPy** as part of it’s implementation
+  - This calls into **SciPy** as part of its implementation
 """
 
 import torch
+from numpy.fft import rfft2, irfft2
 from torch.autograd import Function
 
+
 ###############################################################
 # Parameter-less example
 # ----------------------
@@ -31,8 +33,6 @@
 #
 # **Layer Implementation**
 
-from numpy.fft import rfft2, irfft2
-
 
 class BadFFTFunction(Function):
 
@@ -46,6 +46,7 @@ def backward(self, grad_output):
         result = irfft2(numpy_go)
         return grad_output.new(result)
 
+
 # since this layer does not have any parameters, we can
 # simply declare this as a function, rather than as an nn.Module class
 
@@ -53,6 +54,7 @@ def incorrect_fft(input):
     return BadFFTFunction()(input)
 
+
 ###############################################################
 # **Example usage of the created layer:**
 
@@ -66,20 +68,13 @@ def incorrect_fft(input):
 # Parametrized example
 # --------------------
 #
-# This implements a layer with learnable weights.
-#
-# It implements the Cross-correlation with a learnable kernel.
-#
-# In deep learning literature, it’s confusingly referred to as
-# Convolution.
+# In deep learning literature, this layer is confusingly referred to as convolution while the actual operation is
+# cross-correlation (the only difference is that filter is flipped for convolution,
+# which is not the case for cross-correlation).
 #
-# The backward computes the gradients wrt the input and gradients wrt the
-# filter.
+# Implementation of a layer with learnable weights, where cross-correlation has a kernel that represents weights.
 #
-# **Implementation:**
-#
-# *Please Note that the implementation serves as an illustration, and we
-# did not verify it’s correctness*
+# The backward pass computes the gradient wrt the input and the gradient wrt the filter.
 
 from numpy import flip
 import numpy as np
@@ -87,6 +82,7 @@
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
+
 class ScipyConv2dFunction(Function):
     @staticmethod
     def forward(ctx, input, filter, bias):
@@ -139,4 +135,4 @@ def forward(self, input):
 input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
 # print("input: ", input)
 test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
-print("Are the gradients correct: ", test)
\ No newline at end of file
+print("Are the gradients correct: ", test)
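The rewritten comment in this patch states that convolution is just cross-correlation with a flipped filter. A small self-contained check of that claim (an illustrative sketch, not part of the patch series):

import numpy as np
from scipy.signal import convolve2d, correlate2d

a = np.random.randn(5, 5)
k = np.random.randn(3, 3)

conv = convolve2d(a, k, mode='valid')
# correlation with the kernel flipped along both axes reproduces convolution
corr = correlate2d(a, np.flip(np.flip(k, 0), 1), mode='valid')
print(np.allclose(conv, corr))  # expected: True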
From 28d90f647cc98a2daf82ed0757c0f854245348b9 Mon Sep 17 00:00:00 2001
From: Adam Dziedzic
Date: Sat, 30 Jun 2018 05:51:29 -0500
Subject: [PATCH 3/6] keep the include in the previous place

---
 advanced_source/numpy_extensions_tutorial.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index 5508587907e..be019ad7682 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -18,7 +18,6 @@
 """
 
 import torch
-from numpy.fft import rfft2, irfft2
 from torch.autograd import Function
 
 
@@ -33,6 +32,8 @@
 #
 # **Layer Implementation**
 
+from numpy.fft import rfft2, irfft2
+
 
 class BadFFTFunction(Function):
 

From df2ccae637146a3c857c607cbc45e6650573fe86 Mon Sep 17 00:00:00 2001
From: adam-dziedzic
Date: Sat, 30 Jun 2018 05:55:53 -0500
Subject: [PATCH 4/6] remove spurious new lines

---
 advanced_source/numpy_extensions_tutorial.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index be019ad7682..31eb0e15241 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -20,7 +20,6 @@
 import torch
 from torch.autograd import Function
 
-
 ###############################################################
 # Parameter-less example
 # ----------------------
@@ -47,7 +46,6 @@ def backward(self, grad_output):
         result = irfft2(numpy_go)
         return grad_output.new(result)
 
-
 # since this layer does not have any parameters, we can
 # simply declare this as a function, rather than as an nn.Module class
 
@@ -55,7 +53,6 @@ def incorrect_fft(input):
     return BadFFTFunction()(input)
 
-
 ###############################################################
 # **Example usage of the created layer:**
 

From c7ab8a30fc9a740a940897aa5665e6d8abebeb37 Mon Sep 17 00:00:00 2001
From: Adam Dziedzic
Date: Sat, 30 Jun 2018 06:03:43 -0500
Subject: [PATCH 5/6] small changes

---
 advanced_source/numpy_extensions_tutorial.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index be019ad7682..caad5de33d5 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -69,11 +69,13 @@ def incorrect_fft(input):
 # Parametrized example
 # --------------------
 #
-# In deep learning literature, this layer is confusingly referred to as convolution while the actual operation is
-# cross-correlation (the only difference is that filter is flipped for convolution,
+# In deep learning literature, this layer is confusingly referred
+# to as convolution while the actual operation is cross-correlation
+# (the only difference is that filter is flipped for convolution,
 # which is not the case for cross-correlation).
 #
-# Implementation of a layer with learnable weights, where cross-correlation has a kernel that represents weights.
+# Implementation of a layer with learnable weights, where cross-correlation
+# has a filter (kernel) that represents weights.
 #
 # The backward pass computes the gradient wrt the input and the gradient wrt the filter.
 
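Since the layer computes cross-correlation, it can also be sanity-checked against PyTorch's built-in conv2d, which (despite the name) computes cross-correlation as well. A hypothetical check, not part of the patches: the view calls add the batch and channel dimensions that F.conv2d expects, and the tolerance is an assumption.

import torch
import torch.nn.functional as F

module = ScipyConv2d(3, 3)
x = torch.randn(10, 10)
with torch.no_grad():
    ours = module(x)
    # the (1, 1)-shaped bias is broadcast-added, mirroring ScipyConv2dFunction
    reference = F.conv2d(x.view(1, 1, 10, 10),
                         module.filter.view(1, 1, 3, 3)).view(8, 8) + module.bias
print(torch.allclose(ours, reference, atol=1e-6))  # expected: True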
From 24fd037a4e634b53d95bd08922b3d0a4d8f10f8c Mon Sep 17 00:00:00 2001
From: Adam Dziedzic
Date: Sat, 30 Jun 2018 06:58:29 -0500
Subject: [PATCH 6/6] flipped filter or cross-correlation

---
 advanced_source/numpy_extensions_tutorial.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py
index 957371fb4e3..e39e94a0319 100644
--- a/advanced_source/numpy_extensions_tutorial.py
+++ b/advanced_source/numpy_extensions_tutorial.py
@@ -78,7 +78,7 @@ def incorrect_fft(input):
 
 from numpy import flip
 import numpy as np
-from scipy.signal import correlate2d
+from scipy.signal import convolve2d, correlate2d
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
@@ -99,15 +99,17 @@ def backward(ctx, grad_output):
         input, filter, bias = ctx.saved_tensors
         grad_output = grad_output.numpy()
         grad_bias = np.sum(grad_output, keepdims=True)
-        grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
+        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
+        # the previous line can be expressed equivalently as:
+        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
         grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
         return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
 
 
 class ScipyConv2d(Module):
-    def __init__(self, kh, kw):
+    def __init__(self, filter_width, filter_height):
         super(ScipyConv2d, self).__init__()
-        self.filter = Parameter(torch.randn(kh, kw))
+        self.filter = Parameter(torch.randn(filter_width, filter_height))
         self.bias = Parameter(torch.randn(1, 1))
 
     def forward(self, input):
@@ -128,11 +130,10 @@ def forward(self, input):
 ###############################################################
 # **Check the gradients:**
 
-from torch.autograd import gradcheck
+from torch.autograd.gradcheck import gradcheck
 
 moduleConv = ScipyConv2d(3, 3)
 
 input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
-# print("input: ", input)
 test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
 print("Are the gradients correct: ", test)
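For intuition, the gradcheck call used throughout the series compares the analytical gradients returned by backward against finite-difference estimates. A minimal sketch of that idea for a single input entry (illustrative only; casting the module to double is an assumption made here to keep the finite differences accurate):

import torch

module = ScipyConv2d(3, 3).double()
x = torch.randn(10, 10, dtype=torch.double, requires_grad=True)

# analytical gradient of loss = output.sum() with respect to x[0, 0]
module(x).sum().backward()
analytical = x.grad[0, 0].item()

# central finite-difference estimate of the same derivative
eps = 1e-6
with torch.no_grad():
    x_plus, x_minus = x.clone(), x.clone()
    x_plus[0, 0] += eps
    x_minus[0, 0] -= eps
    numerical = ((module(x_plus).sum() - module(x_minus).sum()) / (2 * eps)).item()

print(abs(analytical - numerical) < 1e-4)  # expected: True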