Commit fa40da4

Author: Will Feng
Merge branch 'master' into jlin27-cpp-frontend-remove-conv5
2 parents: 322b101 + a300b1d

15 files changed (+158, -230 lines)

advanced_source/cpp_export.rst

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-3. Loading a TorchScript Model in C++
+Loading a TorchScript Model in C++
 =====================================

 **This tutorial was updated to work with PyTorch 1.2**

advanced_source/cpp_extension.rst

Lines changed: 45 additions & 44 deletions
@@ -946,7 +946,8 @@ without having to convert to a single pointer:
 Accessor objects have a relatively high level interface, with ``.size()`` and
 ``.stride()`` methods and multi-dimensional indexing. The ``.accessor<>``
 interface is designed to access data efficiently on cpu tensor. The equivalent
-for cuda tensors is the ``packed_accessor<>``, which produces a Packed Accessor.
+for cuda tensors are ``packed_accessor64<>`` and ``packed_accessor32<>``, which
+produce Packed Accessors with either 64-bit or 32-bit integer indexing.

 The fundamental difference with Accessor is that a Packed Accessor copies size
 and stride data inside of its structure instead of pointing to it. It allows us
@@ -957,34 +958,34 @@ We can design a function that takes Packed Accessors instead of pointers.
 .. code-block:: cpp

    __global__ void lltm_cuda_forward_kernel(
-       const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell)
+       const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)

 Let's decompose the template used here. the first two arguments ``scalar_t`` and
 ``2`` are the same as regular Accessor. The argument
 ``torch::RestrictPtrTraits`` indicates that the ``__restrict__`` keyword must be
-used. Finally, the argument ``size_t`` indicates that sizes and strides must be
-stored in a ``size_t`` integer. This is important as by default ``int64_t`` is
-used and can make the kernel slower.
+used. Note also that we've used the ``PackedAccessor32`` variant which store the
+sizes and strides in an ``int32_t``. This is important as using the 64-bit
+variant (``PackedAccessor64``) can make the kernel slower.

 The function declaration becomes

 .. code-block:: cpp

    template <typename scalar_t>
    __global__ void lltm_cuda_forward_kernel(
-       const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+       const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
      //batch index
      const int n = blockIdx.y;
      // column index
@@ -1000,7 +1001,7 @@ The function declaration becomes
    }

 The implementation is much more readable! This function is then called by
-creating Packed Accessors with the ``.packed_accessor<>`` method within the
+creating Packed Accessors with the ``.packed_accessor32<>`` method within the
 host function.

 .. code-block:: cpp
@@ -1029,13 +1030,13 @@ host function.

    AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
      lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-         gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-         old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+         gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+         old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
    }));

    return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
@@ -1048,15 +1049,15 @@ on it:

    template <typename scalar_t>
    __global__ void lltm_cuda_backward_kernel(
-       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-       torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-       const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+       torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+       const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+       const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
      //batch index
      const int n = blockIdx.y;
      // column index
@@ -1102,15 +1103,15 @@ on it:

    AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
      lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-         d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-         grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-         gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+         d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+         grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+         gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
    }));

    auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
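
To get a feel for when the 32-bit accessor variant applies, here is a rough Python-side sketch. The helper below is hypothetical (it is not part of the tutorial or of ATen's dispatch logic); it only illustrates the constraint that ``packed_accessor32`` assumes sizes, strides, and the total element count fit in a 32-bit integer.

    import torch

    INT32_MAX = 2**31 - 1

    def fits_32bit_indexing(*tensors):
        # Hypothetical check: packed_accessor32 stores sizes and strides as
        # 32-bit integers, so each dimension size and the total element count
        # should fit in int32; otherwise packed_accessor64 is the safe choice.
        return all(
            t.numel() <= INT32_MAX and all(s <= INT32_MAX for s in t.size())
            for t in tensors
        )

    gates = torch.randn(16, 3, 128)    # stand-in shapes, not the tutorial's real inputs
    old_cell = torch.randn(16, 128)
    print(fits_32bit_indexing(gates, old_cell))  # True for tensors this small

In the tutorial itself the choice is made once in the C++ host function, as in the hunks above; tensors large enough to overflow 32-bit indexing would need ``packed_accessor64`` instead.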

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 """
-4. (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
+(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
 ========================================================================

 In this tutorial, we describe how to convert a model defined

beginner_source/Intro_to_TorchScript_tutorial.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 """
-2. Introduction to TorchScript
+Introduction to TorchScript
 ===========================

 *James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com)*, rev2
@@ -24,7 +24,7 @@
 - How to compose both approaches
 - Saving and loading TorchScript modules

-We hope that after you complete this tutorial, you proceed to go through
+We hope that after you complete this tutorial, you will proceed to go through
 `the follow-on tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html>`_
 which will walk you through an example of actually calling a TorchScript
 model from C++.

beginner_source/aws_distributed_training_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 """
-4. (advanced) PyTorch 1.0 Distributed Trainer with Amazon AWS
+(advanced) PyTorch 1.0 Distributed Trainer with Amazon AWS
 =============================================================

 **Author**: `Nathan Inkawhich <https://github.com/inkawhich>`_

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 16 additions & 0 deletions
@@ -185,6 +185,15 @@ def forward(self, x):
 print('Finished Training')

 ########################################################################
+# Let's quickly save our trained model:
+
+PATH = './cifar_net.pth'
+torch.save(net.state_dict(), PATH)
+
+########################################################################
+# See `here <https://pytorch.org/docs/stable/notes/serialization.html>`_
+# for more details on saving PyTorch models.
+#
 # 5. Test the network on the test data
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
@@ -204,6 +213,13 @@ def forward(self, x):
 imshow(torchvision.utils.make_grid(images))
 print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

+########################################################################
+# Next, let's load back in our saved model (note: saving and re-loading the model
+# wasn't necessary here, we only did it to illustrate how to do so):
+
+net = Net()
+net.load_state_dict(torch.load(PATH))
+
 ########################################################################
 # Okay, now let us see what the neural network thinks these examples above are:
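
Taken together, the two new cells form the usual ``state_dict`` save/load round trip. A self-contained sketch of the same pattern follows; the module and path here are illustrative stand-ins, not the tutorial's ``Net``.

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):               # toy stand-in for the tutorial's Net
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    PATH = './tiny_net.pth'                 # illustrative path
    net = TinyNet()
    torch.save(net.state_dict(), PATH)      # saves only the parameters, not the class

    net2 = TinyNet()                        # re-create the architecture first
    net2.load_state_dict(torch.load(PATH))  # then load the saved parameters into it
    net2.eval()                             # switch to eval mode before inference

Saving the ``state_dict`` rather than the whole module keeps the checkpoint decoupled from the class definition, which is why the tutorial re-instantiates ``Net()`` before calling ``load_state_dict``.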

beginner_source/chatbot_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -537,7 +537,7 @@ def outputVar(l, voc):
     max_target_len = max([len(indexes) for indexes in indexes_batch])
     padList = zeroPadding(indexes_batch)
     mask = binaryMatrix(padList)
-    mask = torch.ByteTensor(mask)
+    mask = torch.BoolTensor(mask)
     padVar = torch.LongTensor(padList)
     return padVar, mask, max_target_len
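
The one-line change above switches the padding mask from ``torch.ByteTensor`` to ``torch.BoolTensor``: as of PyTorch 1.2, boolean tensors are the preferred dtype for masking operations, and uint8 masks trigger deprecation warnings. A small sketch of how such a mask is typically consumed (toy data, not taken from the tutorial):

    import torch

    scores = torch.tensor([[0.2, 0.8, 0.0],
                           [0.5, 0.1, 0.0]])
    pad = [[1, 1, 0],
           [1, 0, 0]]                        # 1 = real token, 0 = padding

    mask = torch.BoolTensor(pad)             # boolean mask, as in the updated tutorial
    print(scores.masked_select(mask))        # tensor([0.2000, 0.8000, 0.5000])
    print(scores.masked_fill(~mask, -1e9))   # overwrite the padded positions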

beginner_source/data_loading_tutorial.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 """
-Data Loading and Processing Tutorial
-====================================
+Writing Custom Datasets, DataLoaders and Transforms
+===================================================
 **Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

 A lot of effort in solving any machine learning problem goes in to

beginner_source/pytorch_with_examples.rst

Lines changed: 0 additions & 37 deletions
@@ -123,43 +123,6 @@ network:

 .. includenodoc:: /beginner/examples_autograd/two_layer_net_custom_function.py

-TensorFlow: Static Graphs
--------------------------
-
-PyTorch autograd looks a lot like TensorFlow: in both frameworks we
-define a computational graph, and use automatic differentiation to
-compute gradients. The biggest difference between the two is that
-TensorFlow's computational graphs are **static** and PyTorch uses
-**dynamic** computational graphs.
-
-In TensorFlow, we define the computational graph once and then execute
-the same graph over and over again, possibly feeding different input
-data to the graph. In PyTorch, each forward pass defines a new
-computational graph.
-
-Static graphs are nice because you can optimize the graph up front; for
-example a framework might decide to fuse some graph operations for
-efficiency, or to come up with a strategy for distributing the graph
-across many GPUs or many machines. If you are reusing the same graph
-over and over, then this potentially costly up-front optimization can be
-amortized as the same graph is rerun over and over.
-
-One aspect where static and dynamic graphs differ is control flow. For
-some models we may wish to perform different computation for each data
-point; for example a recurrent network might be unrolled for different
-numbers of time steps for each data point; this unrolling can be
-implemented as a loop. With a static graph the loop construct needs to
-be a part of the graph; for this reason TensorFlow provides operators
-such as ``tf.scan`` for embedding loops into the graph. With dynamic
-graphs the situation is simpler: since we build graphs on-the-fly for
-each example, we can use normal imperative flow control to perform
-computation that differs for each input.
-
-To contrast with the PyTorch autograd example above, here we use
-TensorFlow to fit a simple two-layer net:
-
-.. includenodoc:: /beginner/examples_autograd/tf_two_layer_net.py
-
 `nn` module
 ===========

beginner_source/transfer_learning_tutorial.py

Lines changed: 5 additions & 5 deletions
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-Transfer Learning Tutorial
-==========================
+Transfer Learning for Computer Vision Tutorial
+==============================================
 **Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

-In this tutorial, you will learn how to train your network using
-transfer learning. You can read more about the transfer learning at `cs231n
-notes <https://cs231n.github.io/transfer-learning/>`__
+In this tutorial, you will learn how to train a convolutional neural network for
+image classification using transfer learning. You can read more about the transfer
+learning at `cs231n notes <https://cs231n.github.io/transfer-learning/>`__


 Quoting these notes,
