
Update FSDP tutorial #2116

Merged: 5 commits, Nov 15, 2022
29 changes: 16 additions & 13 deletions intermediate_source/FSDP_tutorial.rst
@@ -62,11 +62,15 @@ We add the following code snippets to a python script “FSDP_mnist.py”.

1.2 Import necessary packages

.. note::
This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`.

.. code-block:: python

# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
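
As an aside on the compatibility note above: a minimal, hypothetical way to support both policy names in one script is to guard the import. The alias `wrap_policy_fn` below is purely illustrative and is not part of the tutorial or this diff.

.. code-block:: python

    # Hedged sketch: use whichever wrap policy the installed PyTorch provides.
    try:
        # PyTorch 1.12 and later
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy as wrap_policy_fn
    except ImportError:
        # earlier releases shipped the same idea under a different name
        from torch.distributed.fsdp.wrap import default_auto_wrap_policy as wrap_policy_fn
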
@@ -82,14 +86,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel as FSDP,
-    CPUOffload,
-    BackwardPrefetch,
+    CPUOffload,
+    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
-    default_auto_wrap_policy,
-    enable_wrap,
-    wrap,
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
)

1.3 Distributed training setup. As we mentioned, FSDP is a type of data parallelism that requires a distributed training environment, so here we use two helper functions to initialize the processes for distributed training and to clean up afterwards.
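
The paragraph above mentions two helper functions for process-group setup and teardown; their bodies are not shown in this diff. A minimal sketch of what such helpers typically look like, assuming the usual single-node `torch.distributed` initialization (the address, port, and backend below are assumptions):

.. code-block:: python

    import os
    import torch.distributed as dist

    def setup(rank, world_size):
        # assumed rendezvous settings for a single machine
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        # initialize the default process group (NCCL for GPU training)
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        # tear down the process group once training is finished
        dist.destroy_process_group()
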
@@ -196,7 +199,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
-        ])
+    ])

dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
@@ -217,8 +220,8 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
my_auto_wrap_policy = functools.partial(
-    default_auto_wrap_policy, min_num_params=100
-)
+    size_based_auto_wrap_policy, min_num_params=100
+)
torch.cuda.set_device(rank)
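
The hunk above only defines `my_auto_wrap_policy`; the call that applies it falls outside this diff. As a hedged sketch, with the keyword name `auto_wrap_policy` being the PyTorch 1.12+ spelling (older releases used `fsdp_auto_wrap_policy`), the policy is passed when wrapping the model:

.. code-block:: python

    # Sketch: move the model to the current GPU, then shard it with FSDP,
    # letting the size-based policy decide which submodules become FSDP units.
    model = Net().to(rank)
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
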


@@ -248,9 +251,9 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
# use a barrier to make sure training is done on all ranks
dist.barrier()
# state_dict for FSDP model is only available on Nightlies for now
-States = model.state_dict()
-if rank == 0:
-    torch.save(states, "mnist_cnn.pt")
+states = model.state_dict()
+if rank == 0:
+    torch.save(states, "mnist_cnn.pt")

cleanup()
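
The change above fixes the `States`/`states` capitalization bug when saving. As a further hedged sketch, not part of this PR: later PyTorch releases expose `StateDictType` and `FullStateDictConfig` for gathering a full, unsharded checkpoint explicitly before saving (treat the exact API as an assumption to verify against your version):

.. code-block:: python

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,
        StateDictType,
    )

    def save_full_checkpoint(model, rank, path="mnist_cnn.pt"):
        dist.barrier()  # make sure every rank has finished training
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
            states = model.state_dict()  # full weights; populated only on rank 0
        if rank == 0:
            torch.save(states, path)
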

@@ -343,7 +346,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
.. code-block:: python

my_auto_wrap_policy = functools.partial(
-    default_auto_wrap_policy, min_num_params=20000
+    size_based_auto_wrap_policy, min_num_params=20000
)
torch.cuda.set_device(rank)
model = Net().to(rank)
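
To follow up on the 20,000-parameter policy above: a hedged sketch of combining it with parameter CPU offloading, using the `CPUOffload` class imported earlier in the tutorial (whether this PR actually enables offloading is not visible in the diff):

.. code-block:: python

    # Sketch: auto-wrap submodules with >= 20k parameters and offload
    # sharded parameters to CPU to trade speed for GPU memory.
    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),
    )
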