Commit ce58d59

Replace "Learning PyTorch with Examples" with fitting sine function with a third order polynomial (#1265)
* Replace tutorial with fitting sine function with third order polynomial
* more
* Save
* save
* save
* fix
* fix
* fix
* fix
* no tensor.data
* fix
* P3
* save
* save
* save
* save
* fix
* fix
* fix
1 parent 133e5b6 commit ce58d59

13 files changed: +416 -465 lines changed
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Tensors and autograd
-----------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor
holding the gradient of ``x`` with respect to some scalar value.
"""
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # loss is a zero-dimensional (scalar) Tensor;
    # loss.item() gets the Python number held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
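
A minimal standalone sketch (not part of this commit) of the autograd behaviour the docstring above describes; the tensor value below is an illustrative assumption:

import torch

# After backward(), x.grad holds the gradient of the scalar loss w.r.t. x.
x = torch.tensor(2.0, requires_grad=True)
loss = x ** 3        # loss = x^3, a scalar
loss.backward()      # populates x.grad
print(x.grad)        # tensor(12.) == 3 * x**2 evaluated at x = 2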
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Defining New autograd Functions
----------------------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as
:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is
the `Legendre polynomial`_ of degree three.

.. _Legendre polynomial:
    https://en.wikipedia.org/wiki/Legendre_polynomials

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we implement our own custom autograd function to perform :math:`P_3(x)`;
its backward pass uses the derivative :math:`P_3'(x)=\frac{3}{2}\left(5x^2-1\right)`.
"""
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for the backward computation. You can cache tensors
        for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x); these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
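
As an optional sanity check (an aside, not part of this commit), a custom Function's hand-written backward can be compared against numerical gradients with torch.autograd.gradcheck; this sketch assumes LegendrePolynomial3 from the file above is in scope and uses double precision, which gradcheck expects:

import torch

# Verify the analytic backward of LegendrePolynomial3 against finite differences.
x = torch.randn(20, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(LegendrePolynomial3.apply, (x,), eps=1e-6, atol=1e-4))  # True if they match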

beginner_source/examples_autograd/tf_two_layer_net.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

beginner_source/examples_autograd/two_layer_net_autograd.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments