Updating Nightly Build Branch #1159

Merged
merged 21 commits into from
Sep 24, 2020
Changes from all commits
21 commits
fa7cff7
Fix typo (#1118)
malfet Aug 12, 2020
c61e49c
Recover the attributes of torch in memory_format_tutorial (#1112)
guyang3532 Aug 19, 2020
fdbe99c
fix bugs for data_loading_tutorial and dcgan_faces_tutorial (#1092)
guyang3532 Aug 19, 2020
de2571f
Update autocast in dispatcher tutorial (#1128)
mcarilli Aug 20, 2020
6ab3a37
Corrected model.resnet50() spelling (#1139)
ucalyptus Aug 26, 2020
b458ced
Fix typo & Minor changes (#1138)
codingbowoo Aug 27, 2020
770abb2
Run win_test_worker manually (#1142)
malfet Aug 27, 2020
4bfe338
Disable `pytorch_windows_builder_worker` config (#1143)
malfet Aug 28, 2020
1707a90
Update index.rst (#1140)
brianjo Aug 28, 2020
e60c6af
Update index.rst
brianjo Aug 28, 2020
3191c0c
LSTM's -> LSTMs in sequence_models_tutorial.py docs (#1136)
adelevie Aug 28, 2020
ae17b15
Added Ray Tune Hyperparameter Tuning Tutorial (#1066)
krfricke Aug 31, 2020
4a70101
Fix typo in "Introduction to Pytorch" tutorial (in NLP tutorial) (#1145)
viswavi Sep 1, 2020
ee5e448
Install torch not torch vision (#1153)
ranman Sep 9, 2020
fe33b54
Python recipe for automatic mixed precision (#1137)
mcarilli Sep 15, 2020
ba6070e
Fix model to be properly exported to ONNX (#1144)
Sep 15, 2020
cba6b85
Dist rpc merge (#1158)
brianjo Sep 17, 2020
f1e682e
Fix typo "asynchronizely" -> "asynchronously" (#1154)
PWhiddy Sep 21, 2020
468124c
Update dist_overview with additional information. (#1155)
pritamdamania87 Sep 21, 2020
96d4201
Add Performance Tuning guide recipe (#1161)
szmigacz Sep 22, 2020
258f422
A fix for one line comment when removing runnable code. (#1165)
Sep 24, 2020
13 changes: 7 additions & 6 deletions .circleci/config.yml
@@ -562,10 +562,11 @@ workflows:
branches:
only:
- master
- pytorch_windows_build_worker:
name: win_test_worker
filters:
branches:
only:
- master
# - pytorch_windows_build_worker:
# name: win_test_worker
# type: approval
# filters:
# branches:
# only:
# - master

8 changes: 8 additions & 0 deletions .jenkins/remove_runnable_code.py
@@ -16,9 +16,17 @@
if line.startswith('#'):
    ret_lines.append(line)
    state = STATE_NORMAL
elif ((line.startswith('"""') or line.startswith('r"""')) and
        line.endswith('"""')):
    ret_lines.append(line)
    state = STATE_NORMAL
elif line.startswith('"""') or line.startswith('r"""'):
    ret_lines.append(line)
    state = STATE_IN_MULTILINE_COMMENT_BLOCK_DOUBLE_QUOTE
elif ((line.startswith("'''") or line.startswith("r'''")) and
        line.endswith("'''")):
    ret_lines.append(line)
    state = STATE_NORMAL
elif line.startswith("'''") or line.startswith("r'''"):
    ret_lines.append(line)
    state = STATE_IN_MULTILINE_COMMENT_BLOCK_SINGLE_QUOTE
Binary file added _static/img/ray-tune.png
Binary file added _static/img/thumbnails/cropped/amp.png
Binary file added _static/img/thumbnails/cropped/profile.png
109 changes: 85 additions & 24 deletions advanced_source/dispatcher.rst
@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
that provides implementations for all basic operators on the XLA dispatch
key.

.. _autograd-support:

Adding autograd support
-----------------------

@@ -229,38 +231,97 @@ Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
(AMP). An autocast kernel typically modifies the operation of an operator by casting the
input arguments to some precision before carrying out the operation. For some
operations, it is numerically safe to cast to lower precision, which is how AMP
can achieve speed ups and reduced memory usage without sacrificing much
accuracy. A nontrivial autocast kernel looks something like this:
`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
to some preferred precision before running the op.
For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
and use less memory in ``float16`` without impairing convergence.
Autocast wrappers only have an effect in
`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.

Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp

  // Autocast-specific helper functions
  #include <ATen/autocast_mode.h>

  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                    at::autocast::cached_cast(at::kHalf, other));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mymatmul", mymatmul_autocast);
  }

``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
otherwise, it leaves ``tensor`` unchanged (cf. the
`eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
This ensures if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.
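
As a concrete sketch of that forced-``float16`` alternative, the wrapper still excludes the
``Autocast`` key before redispatching but skips ``cached_cast`` entirely
(``mymatmul_autocast_forced_half`` is an invented name used only for illustration):

.. code-block:: cpp

  // Sketch only: force float16 execution for all input types instead of
  // following the native eligibility policy via cached_cast.
  Tensor mymatmul_autocast_forced_half(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(self.half(), other.half());
  }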

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching. By default, if no autocast kernel is provided,
we simply fallthrough directly to the regular operator implementation (no
autocasting occurs.) (We didn't use ``myadd`` for this example, since pointwise
addition doesn't do autocasting and should just fall through).

When should an autocast kernel be registered? Unfortunately, there aren't
cut-and-dry rules for when you should cast to a lower precision. You can
get a sense for what operators have autocasting behavior by looking at
the `AMP documentation
<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
general rules:

* Operations that do reductions should be carried out in float32,
* Any operation with multiple float tensor inputs has to standardize them
to a common precision, and
* Any operation that does a convolution or gemm under the hood should
probably be float16
dispatch before redispatching.

By default, if no autocast wrapper is provided,
we fallthrough directly to the regular operator implementation (no
autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
addition doesn't need autocasting and should just fall through.)
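
If you prefer to make that default explicit in your own extension (purely optional; this is a
sketch rather than something the tutorial requires), you can register a fallthrough kernel
for the ``Autocast`` key with ``torch::CppFunction::makeFallthrough()``:

.. code-block:: cpp

  #include <torch/library.h>

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    // Equivalent to registering nothing for this op: dispatch skips the
    // Autocast key and falls through to the regular myadd implementation.
    m.impl("myadd", torch::CppFunction::makeFallthrough());
  }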

When should an autocast wrapper be registered? Unfortunately, there aren't
cut-and-dried rules for an op's preferred precision. You can
get a sense for some native ops' preferred precisions by looking at the
`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
General guidance:

* Ops that do reductions should probably execute in ``float32`` (see the sketch after this list),
* Any op that does a convolution or gemm under the hood should
  probably execute in ``float16``, and
* Other ops with multiple floating-point tensor inputs should standardize
  them to a common precision (unless the implementation supports inputs with different precisions).
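
For the first category, for example, a wrapper can reuse ``cached_cast`` with ``at::kFloat``
to run the op in ``float32``; the sketch below uses a hypothetical reduction-style op
``mynorm`` (the op and wrapper names are invented for illustration):

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor mynorm_autocast(const Tensor& self) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // Reductions are numerically safer in float32, so cast eligible float16
    // CUDA inputs up rather than down.
    return mynorm(at::autocast::cached_cast(at::kFloat, self));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mynorm", mynorm_autocast);
  }

This mirrors the pattern natively autocasted ops on the ``float32`` cast list follow.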

If your custom op falls into the third category, the ``promote_type`` template
helps figure out the widest floating-point type present among input tensors, which is
the safest choice for the execution type:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // The required at::kHalf argument is an optimistic initial guess.
    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                                at::autocast::cached_cast(exec_type, t1));
  }

If your custom op is :ref:`autograd-enabled<autograd-support>`, you only need to write and register
an autocast wrapper for the same name onto which the autograd wrapper is registered.
For example, if you wanted an autocast wrapper for the ``myadd`` function shown
in the autograd section, all you'd need is

.. code-block:: cpp

  Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return myadd(at::autocast::cached_cast(<desired dtype>, self),
                 at::autocast::cached_cast(<desired dtype>, other));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("myadd", myadd_autocast);
  }

There are no separate gymnastics to make the backward method autocast compatible.
However, the backward method defined in your custom autograd function will run in the same
dtype as autocast sets for the forward method, so you should choose a ``<desired dtype>``
suitable for both your forward and backward methods.

Batched
^^^^^^^
2 changes: 1 addition & 1 deletion beginner_source/data_loading_tutorial.py
@@ -374,7 +374,7 @@ def __call__(self, sample):
#

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)
                        shuffle=True, num_workers=0)


# Helper function to show a batch
2 changes: 1 addition & 1 deletion beginner_source/dcgan_faces_tutorial.py
@@ -591,7 +591,7 @@ def forward(self, input):
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, device=device)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
10 changes: 10 additions & 0 deletions beginner_source/dist_overview.rst
@@ -195,3 +195,13 @@ RPC Tutorials are listed below:
   `@rpc.functions.async_execution <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.functions.async_execution>`__
   decorator, which can help speed up inference and training. It uses similar
   RL and PS examples employed in the above tutorials 1 and 2.
5. The `Combining Distributed DataParallel with Distributed RPC Framework <../advanced/rpc_ddp_tutorial.html>`__
   tutorial demonstrates how to combine DDP with RPC to train a model using
   distributed data parallelism combined with distributed model parallelism.


PyTorch Distributed Developers
------------------------------

If you'd like to contribute to PyTorch Distributed, please refer to our
`Developer Guide <https://github.com/pytorch/pytorch/blob/master/torch/distributed/CONTRIBUTING.md>`_.