
Commit b9bc71c

Merge branch 'master' into DQN_revise_training
2 parents df822d4 + 4be5bf0 commit b9bc71c

File tree

5 files changed (+21, -18 lines)


beginner_source/basics/optimization_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 ===========================
 
 Now that we have a model and data it's time to train, validate and test our model by optimizing its parameters on
-our data. Training a model is an iterative process; in each iteration (called an *epoch*) the model makes a guess about the output, calculates
+our data. Training a model is an iterative process; in each iteration the model makes a guess about the output, calculates
 the error in its guess (*loss*), collects the derivatives of the error with respect to its parameters (as we saw in
 the `previous section <autograd_tutorial.html>`_), and **optimizes** these parameters using gradient descent. For a more
 detailed walkthrough of this process, check out this video on `backpropagation from 3Blue1Brown <https://www.youtube.com/watch?v=tIeHLnjs5U8>`__.
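
As context for this one-line wording change, the loop the paragraph describes looks schematically like the sketch below (a hedged illustration, not code from this commit; `model`, `dataloader`, `loss_fn`, and `optimizer` are assumed to be defined as in the surrounding tutorial):

.. code-block:: python

    for X, y in dataloader:
        pred = model(X)          # the model makes a guess about the output
        loss = loss_fn(pred, y)  # calculate the error in its guess (the loss)
        optimizer.zero_grad()
        loss.backward()          # collect derivatives of the error w.r.t. the parameters
        optimizer.step()         # optimize the parameters using gradient descent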

distributed/home.rst

Lines changed: 2 additions & 2 deletions
@@ -51,8 +51,8 @@ Learn DDP
    :link: https://pytorch.org/tutorials/advanced/generic_join.html?utm_source=distr_landing&utm_medium=generic_join
    :link-type: url
 
-   This tutorial provides a short and gentle intro to the PyTorch
-   DistributedData Parallel.
+   This tutorial describes the Join context manager and
+   demonstrates it's use with DistributedData Parallel.
    +++
    :octicon:`code;1em` Code
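
For readers unfamiliar with the Join context manager the new card description mentions, its use with DistributedDataParallel looks roughly like this (a hedged sketch, not part of this commit; `rank` and the per-rank `inputs` are assumed, and the point of Join is that `inputs` may have different lengths across ranks):

.. code-block:: python

    import torch
    from torch.distributed.algorithms.join import Join
    from torch.nn.parallel import DistributedDataParallel as DDP

    # assumes a process group is already initialized for this rank
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])

    with Join([model]):  # ranks that exhaust their inputs "join" early
        for inp in inputs:
            model(inp).sum().backward()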

index.rst

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ What's new in PyTorch tutorials?
 
 .. customcarditem::
    :header: NLP from Scratch: Generating Names with a Character-level RNN
-   :card_description: After using character-level RNN to classify names, leanr how to generate names from languages. Second in a series of three tutorials.
+   :card_description: After using character-level RNN to classify names, learn how to generate names from languages. Second in a series of three tutorials.
    :image: _static/img/thumbnails/cropped/NLP-From-Scratch-Generating-Names-with-a-Character-Level-RNN.png
    :link: intermediate/char_rnn_generation_tutorial.html
    :tags: Text

intermediate_source/FSDP_tutorial.rst

Lines changed: 16 additions & 13 deletions
@@ -62,11 +62,15 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 
 1.2 Import necessary packages
 
+.. note::
+   This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`.
+
 .. code-block:: python
 
     # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
     import os
     import argparse
+    import functools
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
@@ -82,14 +86,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
8286
from torch.utils.data.distributed import DistributedSampler
8387
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8488
from torch.distributed.fsdp.fully_sharded_data_parallel import (
85-
FullyShardedDataParallel as FSDP,
86-
CPUOffload,
87-
BackwardPrefetch,
89+
CPUOffload,
90+
BackwardPrefetch,
8891
)
8992
from torch.distributed.fsdp.wrap import (
90-
default_auto_wrap_policy,
91-
enable_wrap,
92-
wrap,
93+
size_based_auto_wrap_policy,
94+
enable_wrap,
95+
wrap,
9396
)
9497
9598
1.3 Distributed training setup. As we mentioned FSDP is a type of data parallelism which requires a distributed training environment, so here we use two helper functions to initialize the processes for distributed training and clean up.
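
A hedged sketch of how a script could tolerate both policy names the note above mentions (this guard is an illustration, not part of the commit):

.. code-block:: python

    try:
        # PyTorch 1.12 and later
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    except ImportError:
        # earlier releases shipped the policy under its old name
        from torch.distributed.fsdp.wrap import (
            default_auto_wrap_policy as size_based_auto_wrap_policy,
        )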
@@ -196,7 +199,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
         transform=transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))
-            ])
+        ])
 
     dataset1 = datasets.MNIST('../data', train=True, download=True,
                         transform=transform)
@@ -217,8 +220,8 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
     my_auto_wrap_policy = functools.partial(
-        default_auto_wrap_policy, min_num_params=100
-    )
+        size_based_auto_wrap_policy, min_num_params=100
+    )
     torch.cuda.set_device(rank)
 
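
Downstream of this hunk, the policy is handed to the FSDP constructor; roughly (a hedged sketch using the 1.12-era keyword argument, not a line from this diff):

.. code-block:: python

    model = Net().to(rank)
    # auto-wrap submodules whose parameter count exceeds min_num_params
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)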
@@ -248,9 +251,9 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     # use a barrier to make sure training is done on all ranks
     dist.barrier()
     # state_dict for FSDP model is only available on Nightlies for now
-    States = model.state_dict()
-    if rank == 0:
-        torch.save(states, "mnist_cnn.pt")
+    states = model.state_dict()
+    if rank == 0:
+        torch.save(states, "mnist_cnn.pt")
 
     cleanup()
 
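
The rename from `States` to `states` fixes the undefined-name error in the `torch.save` call. Once saved, the checkpoint can be loaded back into a plain, un-sharded model for single-device evaluation (a hedged sketch, not tutorial code):

.. code-block:: python

    model = Net()
    model.load_state_dict(torch.load("mnist_cnn.pt"))
    model.eval()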
@@ -343,7 +346,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
 .. code-block:: python
 
     my_auto_wrap_policy = functools.partial(
-        default_auto_wrap_policy, min_num_params=20000
+        size_based_auto_wrap_policy, min_num_params=20000
     )
     torch.cuda.set_device(rank)
     model = Net().to(rank)
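
The imports this commit touches also bring in CPUOffload and BackwardPrefetch, which combine with the policy along these lines (a hedged sketch of the 1.12-era options, not a line from this diff):

.. code-block:: python

    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),      # keep sharded params in CPU memory
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next shard before backward compute
    )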

recipes_source/recipes_index.rst

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :header: PyTorch Profiler with Instrumentation and Tracing Technology API (ITT API) support
    :card_description: Learn how to use PyTorch's profiler with Instrumentation and Tracing Technology API (ITT API) to visualize operators labeling in Intel® VTune™ Profiler GUI
    :image: ../_static/img/thumbnails/cropped/profiler.png
-   :link: ../recipes/recipes/profile_with_itt.html
+   :link: ../recipes/profile_with_itt.html
    :tags: Basics
 
 .. Interpretability
