Commit 0ab50ad

davidberard98 and Svetlana Karslioglu authored
Update FSDP tutorial (#2116)
* rename default_auto_wrap_policy -> size_based_auto_wrap_policy
* import functools
* indentation

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 8a468e3 commit 0ab50ad
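The user-visible change is the policy rename. For readers following the tutorial on a PyTorch release older than 1.12, an import shim like the sketch below (an illustration, not part of this commit) keeps the same script working across versions:

.. code-block:: python

    # Hypothetical compatibility shim (not from the tutorial): prefer the
    # new name, fall back to the pre-1.12 name if it is not available.
    try:
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    except ImportError:
        from torch.distributed.fsdp.wrap import (
            default_auto_wrap_policy as size_based_auto_wrap_policy,
        )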

1 file changed: 16 additions and 13 deletions

intermediate_source/FSDP_tutorial.rst

@@ -62,11 +62,15 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 
 1.2 Import necessary packages
 
+.. note::
+   This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`.
+
 .. code-block:: python
 
     # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
     import os
     import argparse
+    import functools
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
@@ -82,14 +86,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     from torch.utils.data.distributed import DistributedSampler
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
-        FullyShardedDataParallel as FSDP,
-        CPUOffload,
-        BackwardPrefetch,
+        CPUOffload,
+        BackwardPrefetch,
     )
     from torch.distributed.fsdp.wrap import (
-        default_auto_wrap_policy,
-        enable_wrap,
-        wrap,
+        size_based_auto_wrap_policy,
+        enable_wrap,
+        wrap,
     )
 
 1.3 Distributed training setup. As we mentioned FSDP is a type of data parallelism which requires a distributed training environment, so here we use two helper functions to initialize the processes for distributed training and clean up.
@@ -196,7 +199,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
-        ])
+    ])
 
     dataset1 = datasets.MNIST('../data', train=True, download=True,
                         transform=transform)
@@ -217,8 +220,8 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
     my_auto_wrap_policy = functools.partial(
-        default_auto_wrap_policy, min_num_params=100
-        )
+        size_based_auto_wrap_policy, min_num_params=100
+    )
     torch.cuda.set_device(rank)
 
 
@@ -248,9 +251,9 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     # use a barrier to make sure training is done on all ranks
     dist_barrier()
     # state_dict for FSDP model is only available on Nightlies for now
-    States = model.state_dict()
-    if rank == 0:
-        torch.save(states, "mnist_cnn.pt")
+    states = model.state_dict()
+    if rank == 0:
+        torch.save(states, "mnist_cnn.pt")
 
     cleanup()
 
@@ -343,7 +346,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
 .. code-block:: python
 
     my_auto_wrap_policy = functools.partial(
-        default_auto_wrap_policy, min_num_params=20000
+        size_based_auto_wrap_policy, min_num_params=20000
     )
     torch.cuda.set_device(rank)
     model = Net().to(rank)
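For context, the sketch below shows how the renamed `size_based_auto_wrap_policy` is typically wired into an FSDP-wrapped model. It is a minimal single-process illustration assuming PyTorch 1.12+ and one CUDA device; the toy `nn.Sequential` model, the rendezvous port, and the `min_num_params` value are placeholders, and only the `functools.partial(size_based_auto_wrap_policy, ...)` pattern mirrors the tutorial code changed above.

.. code-block:: python

    # Minimal single-process sketch (assumes PyTorch 1.12+ and one CUDA device).
    import functools
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    # Single-rank process group so FSDP can be constructed without torchrun.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")  # placeholder port
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    # Wrap every submodule whose parameter count exceeds min_num_params.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).cuda()
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    print(model)  # the Linear layers show up inside nested FSDP units

    dist.destroy_process_group()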
