
Commit f158325

adam-dziedzic authored and soumith committed
corrected the example for cross-correlation with scipy (pytorch#268)
* corrected the example with cross-correlation with scipy - the example is fully correct with appropriate gradients
* improve the code and comments; corrected the example with cross-correlation with scipy - the example is fully correct with appropriate gradients
* keep the include in the previous place
* remove spurious new lines
* small changes
* flipped filter or cross-correlation
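
A note on the last bullet (not part of the commit message): "flipped filter" refers to the identity that convolution is cross-correlation with the filter flipped along both axes. A minimal SciPy sketch of that identity, with arbitrary array names:

import numpy as np
from scipy.signal import convolve2d, correlate2d

# convolution == cross-correlation with a doubly-flipped filter
x = np.random.randn(5, 5)   # arbitrary input map
k = np.random.randn(3, 3)   # arbitrary filter

conv = convolve2d(x, k, mode='valid')
corr = correlate2d(x, np.flip(np.flip(k, axis=0), axis=1), mode='valid')

print(np.allclose(conv, corr))  # expected: True

This is the same identity the updated backward() below relies on when it computes grad_input with convolve2d instead of an explicit double flip.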
1 parent: cd4953d · commit: f158325

File tree

1 file changed

advanced_source/numpy_extensions_tutorial.py

Lines changed: 47 additions & 31 deletions
@@ -4,15 +4,17 @@
 =========================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
 
+**Updated by**: `Adam Dziedzic` [https://github.com/adam-dziedzic](https://github.com/adam-dziedzic)
+
 In this tutorial, we shall go through two tasks:
 
 1. Create a neural network layer with no parameters.
 
-    - This calls into **numpy** as part of it’s implementation
+    - This calls into **numpy** as part of its implementation
 
 2. Create a neural network layer that has learnable weights
 
-    - This calls into **SciPy** as part of it’s implementation
+    - This calls into **SciPy** as part of its implementation
 """
 
 import torch
@@ -64,60 +66,74 @@ def incorrect_fft(input):
 # Parametrized example
 # --------------------
 #
-# This implements a layer with learnable weights.
-#
-# It implements the Cross-correlation with a learnable kernel.
-#
-# In deep learning literature, it’s confusingly referred to as
-# Convolution.
-#
-# The backward computes the gradients wrt the input and gradients wrt the
-# filter.
+# In deep learning literature, this layer is confusingly referred
+# to as convolution while the actual operation is cross-correlation
+# (the only difference is that filter is flipped for convolution,
+# which is not the case for cross-correlation).
 #
-# **Implementation:**
+# Implementation of a layer with learnable weights, where cross-correlation
+# has a filter (kernel) that represents weights.
 #
-# *Please Note that the implementation serves as an illustration, and we
-# did not verify it’s correctness*
+# The backward pass computes the gradient wrt the input and the gradient wrt the filter.
 
+from numpy import flip
+import numpy as np
 from scipy.signal import convolve2d, correlate2d
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
 
 class ScipyConv2dFunction(Function):
     @staticmethod
-    def forward(ctx, input, filter):
-        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
-        result = correlate2d(input.numpy(), filter.detach().numpy(), mode='valid')
-        ctx.save_for_backward(input, filter)
-        return input.new(result)
+    def forward(ctx, input, filter, bias):
+        # detach so we can cast to NumPy
+        input, filter, bias = input.detach(), filter.detach(), bias.detach()
+        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
+        result += bias.numpy()
+        ctx.save_for_backward(input, filter, bias)
+        return torch.from_numpy(result)
 
     @staticmethod
     def backward(ctx, grad_output):
         grad_output = grad_output.detach()
-        input, filter = ctx.saved_tensors
-        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
-        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
-
-        return grad_output.new_tensor(grad_input), grad_output.new_tensor(grad_filter)
+        input, filter, bias = ctx.saved_tensors
+        grad_output = grad_output.numpy()
+        grad_bias = np.sum(grad_output, keepdims=True)
+        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
+        # the previous line can be expressed equivalently as:
+        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
+        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
+        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
 
 
 class ScipyConv2d(Module):
-
-    def __init__(self, kh, kw):
+    def __init__(self, filter_width, filter_height):
         super(ScipyConv2d, self).__init__()
-        self.filter = Parameter(torch.randn(kh, kw))
+        self.filter = Parameter(torch.randn(filter_width, filter_height))
+        self.bias = Parameter(torch.randn(1, 1))
 
     def forward(self, input):
-        return ScipyConv2dFunction.apply(input, self.filter)
+        return ScipyConv2dFunction.apply(input, self.filter, self.bias)
+
 
 ###############################################################
 # **Example usage:**
 
 module = ScipyConv2d(3, 3)
-print(list(module.parameters()))
+print("Filter and bias: ", list(module.parameters()))
 input = torch.randn(10, 10, requires_grad=True)
 output = module(input)
-print(output)
+print("Output from the convolution: ", output)
 output.backward(torch.randn(8, 8))
-print(input.grad)
+print("Gradient for the input map: ", input.grad)
+
+###############################################################
+# **Check the gradients:**
+
+from torch.autograd.gradcheck import gradcheck
+
+moduleConv = ScipyConv2d(3, 3)
+
+input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
+test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
+print("Are the gradients correct: ", test)
