
Commit 51f640e

Update FSDP tutorial
* rename default_auto_wrap_policy -> size_based_auto_wrap_policy
* import functools
* fix indentation
1 parent 83d6fec commit 51f640e
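
For readers updating their own scripts against this rename, only the imported identifier changes; a minimal before/after sketch (the exact PyTorch version boundary is not stated in this commit and is left out here):

    # before this commit the tutorial used the old name:
    # from torch.distributed.fsdp.wrap import default_auto_wrap_policy

    # after, only the identifier changes; the policy behaves the same:
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy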

File tree

1 file changed: +12 -12 lines


intermediate_source/FSDP_tutorial.rst

Lines changed: 12 additions & 12 deletions
@@ -67,6 +67,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
 import os
 import argparse
+import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
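
The new functools import exists so that min_num_params can be pre-bound on the wrap policy with functools.partial before FSDP calls it. A tiny standalone illustration of that mechanic (the predicate below is a stand-in, not the real policy's signature):

    import functools

    def wrap_if_big_enough(numel, min_num_params):
        # stand-in predicate: wrap once the parameter count passes the threshold
        return numel >= min_num_params

    # partial() pre-binds min_num_params, leaving a callable that FSDP-style
    # code can invoke with the remaining argument(s)
    policy = functools.partial(wrap_if_big_enough, min_num_params=100)
    assert policy(512) is True
    assert policy(3) is False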
@@ -82,14 +83,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 from torch.utils.data.distributed import DistributedSampler
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
-   FullyShardedDataParallel as FSDP,
-   CPUOffload,
-   BackwardPrefetch,
+    CPUOffload,
+    BackwardPrefetch,
 )
 from torch.distributed.fsdp.wrap import (
-   default_auto_wrap_policy,
-   enable_wrap,
-   wrap,
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
 )

 1.3 Distributed training setup. As we mentioned, FSDP is a type of data parallelism and requires a distributed training environment, so here we use two helper functions to initialize the processes for distributed training and to clean up afterwards.
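
The "two helper functions" referenced in that context line sit outside this hunk. A minimal sketch of what such setup/cleanup helpers conventionally look like; the rendezvous address/port and the nccl backend are assumptions, not taken from the diff:

    import os
    import torch.distributed as dist

    def setup(rank, world_size):
        # assumed rendezvous settings; adjust for your cluster
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the default process group backing FSDP's collectives
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()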
@@ -196,7 +196,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
-        ])
+    ])

     dataset1 = datasets.MNIST('../data', train=True, download=True,
                               transform=transform)
@@ -217,7 +217,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
     my_auto_wrap_policy = functools.partial(
-        default_auto_wrap_policy, min_num_params=100
+        size_based_auto_wrap_policy, min_num_params=100
     )
     torch.cuda.set_device(rank)
     model = Net().to(rank)
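
The my_auto_wrap_policy bound in this hunk is handed to the FSDP constructor just past the visible context. A sketch of that call, assuming the post-rename keyword auto_wrap_policy; the cpu_offload argument is optional and appears here only because CPUOffload is imported above:

    # wrap the model; FSDP consults my_auto_wrap_policy to decide which
    # submodules become their own sharded units
    model = FSDP(model,
                 auto_wrap_policy=my_auto_wrap_policy,
                 cpu_offload=CPUOffload(offload_params=True))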
@@ -248,9 +248,9 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     # use a barrier to make sure training is done on all ranks
     dist.barrier()
     # state_dict for FSDP model is only available on Nightlies for now
-    States = model.state_dict()
-    if rank == 0:
-    torch.save(states, "mnist_cnn.pt")
+    states = model.state_dict()
+    if rank == 0:
+        torch.save(states, "mnist_cnn.pt")

     cleanup()
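
Because rank 0 saves the full (unsharded) state dict after the barrier, the checkpoint can later be consumed by an ordinary single-process load. A hypothetical usage sketch, assuming the tutorial's Net class:

    import torch

    # plain single-process load of the checkpoint written by rank 0
    model = Net()  # the tutorial's model class
    model.load_state_dict(torch.load("mnist_cnn.pt"))
    model.eval()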
@@ -343,7 +343,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
 .. code-block:: python

    my_auto_wrap_policy = functools.partial(
-       default_auto_wrap_policy, min_num_params=20000
+       size_based_auto_wrap_policy, min_num_params=20000
    )
    torch.cuda.set_device(rank)
    model = Net().to(rank)
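
The min_num_params threshold decides which submodules are wrapped as their own FSDP units. A self-contained way to see which layers would clear a 20,000-parameter threshold (the two-layer model below is a stand-in, not the tutorial's Net):

    import torch.nn as nn

    # stand-in model; layer sizes are illustrative only
    net = nn.Sequential(
        nn.Conv2d(1, 32, 3),   # 320 parameters       -> stays in the parent unit
        nn.Linear(9216, 128),  # 1,179,776 parameters -> its own FSDP unit
    )
    for name, module in net.named_children():
        numel = sum(p.numel() for p in module.parameters())
        status = "own FSDP unit" if numel >= 20000 else "parent unit"
        print(f"{name}: {numel} parameters -> {status}")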
