intermediate_source/FSDP_tutorial.rst (16 additions, 13 deletions)
@@ -62,11 +62,15 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 
 1.2 Import necessary packages
 
+.. note::
+   This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`.
+
 .. code-block:: python
 
     # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
     import os
     import argparse
+    import functools
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
@@ -82,14 +86,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
     from torch.utils.data.distributed import DistributedSampler
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
-        FullyShardedDataParallel as FSDP,
-        CPUOffload,
-        BackwardPrefetch,
+        CPUOffload,
+        BackwardPrefetch,
     )
     from torch.distributed.fsdp.wrap import (
-        default_auto_wrap_policy,
-        enable_wrap,
-        wrap,
+        size_based_auto_wrap_policy,
+        enable_wrap,
+        wrap,
     )
 
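For context (this sketch is not part of the diff): the newly imported `functools` and `size_based_auto_wrap_policy` are later combined into the auto-wrap policy passed to the FSDP constructor. A minimal sketch, assuming PyTorch 1.12+'s `auto_wrap_policy` argument, an already-initialized process group, and an illustrative 100-parameter threshold (the helper name is hypothetical):

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    def shard_with_size_policy(model: nn.Module, rank: int) -> FSDP:
        # Wrap every submodule holding at least min_num_params parameters in its
        # own FSDP unit; smaller modules stay flattened into their parent unit.
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=100
        )
        # Assumes torch.distributed.init_process_group(...) has already run.
        return FSDP(model.to(rank), auto_wrap_policy=my_auto_wrap_policy)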
1.3 Distributed training setup. As mentioned, FSDP is a type of data parallelism that requires a distributed training environment, so here we use two helper functions: one to initialize the processes for distributed training and one to clean up afterward.
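A minimal sketch of those two helpers, assuming a single-node run (the "nccl" backend and the localhost rendezvous address/port are placeholders, not necessarily the tutorial's exact values):

    import os
    import torch.distributed as dist

    def setup(rank, world_size):
        # Rendezvous info so every rank can find the others (single-node defaults).
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # Initialize the default process group; NCCL is the usual backend for GPU training.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        # Tear down the process group when training is finished.
        dist.destroy_process_group()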
@@ -196,7 +199,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.