corrected the example for cross-correlation with scipy #268

Merged: 7 commits, Jul 1, 2018
advanced_source/numpy_extensions_tutorial.py: 78 changes (47 additions, 31 deletions)
Creating Extensions Using numpy and scipy
=========================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_

**Updated by**: `Adam Dziedzic <https://github.com/adam-dziedzic>`_

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.

- This calls into **numpy** as part of its implementation

2. Create a neural network layer that has learnable weights

- This calls into **SciPy** as part of its implementation
"""

import torch
from torch.autograd import Function

###############################################################
# Parametrized example
# --------------------
#
# In deep learning literature, this layer is confusingly referred
# to as convolution, while the actual operation is cross-correlation
# (the only difference is that the filter is flipped for convolution,
# which is not the case for cross-correlation). A quick numerical
# check of this distinction follows the imports below.
#
# Implementation of a layer with learnable weights, where the
# cross-correlation filter (kernel) represents the weights.
#
# The backward pass computes the gradient wrt the input and the
# gradient wrt the filter.

from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
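
###############################################################
# A quick numerical check of the statement above (added here for
# illustration; the array names are arbitrary): convolving with a
# kernel gives the same result as cross-correlating with that kernel
# flipped along both axes.

image = np.random.randn(5, 5)
kernel = np.random.randn(3, 3)
print("Convolution equals cross-correlation with a flipped filter: ",
      np.allclose(convolve2d(image, kernel, mode='valid'),
                  correlate2d(image, flip(flip(kernel, 0), 1), mode='valid')))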


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.from_numpy(result)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
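
###############################################################
# Sanity check (an illustrative addition, not part of the original
# tutorial): ``torch.nn.functional.conv2d`` also computes a
# cross-correlation, so its output should match our forward pass
# directly, with no filter flip.

import torch.nn.functional as F

x = torch.randn(10, 10)
w = torch.randn(3, 3)
b = torch.randn(1, 1)
ours = ScipyConv2dFunction.apply(x, w, b)
reference = F.conv2d(x.view(1, 1, 10, 10), w.view(1, 1, 3, 3)).view(8, 8) + b
print("Forward pass matches F.conv2d: ", torch.allclose(ours, reference, atol=1e-6))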


class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)


###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)

###############################################################
# **Check the gradients:**

from torch.autograd.gradcheck import gradcheck

moduleConv = ScipyConv2d(3, 3)

input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)