@@ -4,15 +4,17 @@
 =========================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
 
+**Updated by**: `Adam Dziedzic <https://github.com/adam-dziedzic>`_
+
 In this tutorial, we shall go through two tasks:
 
 1. Create a neural network layer with no parameters.
 
-    - This calls into **numpy** as part of it’s implementation
+    - This calls into **numpy** as part of its implementation
 
 2. Create a neural network layer that has learnable weights
 
-    - This calls into **SciPy** as part of it’s implementation
+    - This calls into **SciPy** as part of its implementation
 """
 
 import torch
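The first task (a layer with no parameters) lives around the `incorrect_fft` function named in the hunk header below, which this diff leaves untouched. As a rough, hypothetical sketch of the pattern such a layer follows — a `torch.autograd.Function` whose forward and backward both drop to NumPy — consider something like the following; the class name and the FFT choice are illustrative, not the tutorial's exact code, and (as the tutorial's own function name hints) the backward is not a mathematically correct gradient:

```python
import numpy as np
import torch
from torch.autograd import Function

class NumpyFFTFunction(Function):  # hypothetical name, for illustration only
    @staticmethod
    def forward(ctx, input):
        # no parameters, so nothing needs to be saved for backward
        numpy_input = input.detach().numpy()
        result = abs(np.fft.rfft2(numpy_input))
        return torch.from_numpy(result)

    @staticmethod
    def backward(ctx, grad_output):
        # note: ignoring abs() here means this is not a true gradient
        numpy_go = grad_output.detach().numpy()
        result = np.fft.irfft2(numpy_go)
        return torch.from_numpy(result)

x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
y = NumpyFFTFunction.apply(x)
y.backward(torch.randn_like(y))
print(x.grad.shape)  # torch.Size([8, 8])
```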
@@ -64,60 +66,74 @@ def incorrect_fft(input):
 # Parametrized example
 # --------------------
 #
-# This implements a layer with learnable weights.
-#
-# It implements the Cross-correlation with a learnable kernel.
-#
-# In deep learning literature, it’s confusingly referred to as
-# Convolution.
-#
-# The backward computes the gradients wrt the input and gradients wrt the
-# filter.
+# In deep learning literature, this layer is confusingly referred
+# to as convolution, while the actual operation is cross-correlation
+# (the only difference is that the filter is flipped for convolution,
+# which is not the case for cross-correlation).
 #
-# **Implementation:**
+# Implementation of a layer with learnable weights, where the
+# cross-correlation filter (kernel) represents the weights.
 #
-# *Please Note that the implementation serves as an illustration, and we
-# did not verify it’s correctness*
+# The backward pass computes the gradients wrt the input, the filter, and the bias.
 
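Since the rewritten comment hinges on the flip relationship, a quick sanity check with SciPy (an editorial aside, not part of the diff) makes it concrete:

```python
import numpy as np
from scipy.signal import convolve2d, correlate2d

a = np.random.randn(5, 5)
k = np.random.randn(3, 3)

# Convolution is cross-correlation with the kernel flipped along both axes.
conv = convolve2d(a, k, mode='valid')
corr_flipped = correlate2d(a, np.flip(k), mode='valid')
print(np.allclose(conv, corr_flipped))  # True
```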
+from numpy import flip
+import numpy as np
 from scipy.signal import convolve2d, correlate2d
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
 
 class ScipyConv2dFunction(Function):
     @staticmethod
-    def forward(ctx, input, filter):
-        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
-        result = correlate2d(input.numpy(), filter.detach().numpy(), mode='valid')
-        ctx.save_for_backward(input, filter)
-        return input.new(result)
+    def forward(ctx, input, filter, bias):
+        # detach so we can cast to NumPy
+        input, filter, bias = input.detach(), filter.detach(), bias.detach()
+        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
+        result += bias.numpy()
+        ctx.save_for_backward(input, filter, bias)
+        return torch.from_numpy(result)
 
     @staticmethod
     def backward(ctx, grad_output):
         grad_output = grad_output.detach()
-        input, filter = ctx.saved_tensors
-        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
-        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
-
-        return grad_output.new_tensor(grad_input), grad_output.new_tensor(grad_filter)
+        input, filter, bias = ctx.saved_tensors
+        grad_output = grad_output.numpy()
+        grad_bias = np.sum(grad_output, keepdims=True)
+        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
+        # the previous line can be expressed equivalently as:
+        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
+        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
+        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
 
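Before the Module wrapper, a quick cross-check of the new forward pass against PyTorch's built-in convolution, which is also cross-correlation under the hood (an editorial aside, not part of the diff). It also shows where the 8x8 shape in the usage example below comes from:

```python
import torch
import torch.nn.functional as F
from scipy.signal import correlate2d

x = torch.randn(10, 10)
w = torch.randn(3, 3)
b = torch.randn(1, 1)

# SciPy path: 'valid' cross-correlation of a 10x10 map with a 3x3 kernel -> 8x8.
scipy_out = torch.from_numpy(correlate2d(x.numpy(), w.numpy(), mode='valid')) + b
# PyTorch path: F.conv2d is also cross-correlation, but expects 4-D (N, C, H, W) tensors.
torch_out = F.conv2d(x.view(1, 1, 10, 10), w.view(1, 1, 3, 3)).squeeze() + b
print(torch.allclose(scipy_out, torch_out, atol=1e-5))  # True
```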
 
 class ScipyConv2d(Module):
-
-    def __init__(self, kh, kw):
+    def __init__(self, filter_width, filter_height):
         super(ScipyConv2d, self).__init__()
-        self.filter = Parameter(torch.randn(kh, kw))
+        self.filter = Parameter(torch.randn(filter_width, filter_height))
+        self.bias = Parameter(torch.randn(1, 1))
 
     def forward(self, input):
-        return ScipyConv2dFunction.apply(input, self.filter)
+        return ScipyConv2dFunction.apply(input, self.filter, self.bias)
+
 
 ###############################################################
 # **Example usage:**
 
 module = ScipyConv2d(3, 3)
-print(list(module.parameters()))
+print("Filter and bias: ", list(module.parameters()))
 input = torch.randn(10, 10, requires_grad=True)
 output = module(input)
-print(output)
+print("Output from the convolution: ", output)
 output.backward(torch.randn(8, 8))
-print(input.grad)
+print("Gradient for the input map: ", input.grad)
+
+###############################################################
+# **Check the gradients:**
+
+from torch.autograd.gradcheck import gradcheck
+
+moduleConv = ScipyConv2d(3, 3)
+
+input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
+test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
+print("Are the gradients correct: ", test)
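`gradcheck` compares the analytical gradients against finite differences, which is why the input above is created in double precision; in float32 the numerical error would swamp `atol`. Called this way it only checks the gradient wrt the tensor passed in `input`. A variant that also exercises the filter and bias gradients (an editorial aside, not in the diff) is to test the `Function` directly:

```python
from torch.autograd import gradcheck

inputs = (torch.randn(10, 10, dtype=torch.double, requires_grad=True),
          torch.randn(3, 3, dtype=torch.double, requires_grad=True),
          torch.randn(1, 1, dtype=torch.double, requires_grad=True))
# Checks the gradients wrt the input, the filter, and the bias in one call.
print(gradcheck(ScipyConv2dFunction.apply, inputs, eps=1e-6, atol=1e-4))  # True
```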