intermediate_source/FSDP_tutorial.rst (12 additions & 12 deletions)
@@ -67,6 +67,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
 import os
 import argparse
+import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -82,14 +83,13 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 from torch.utils.data.distributed import DistributedSampler
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel as FSDP,
-    CPUOffload,
-    BackwardPrefetch,
+    CPUOffload,
+    BackwardPrefetch,
 )
 from torch.distributed.fsdp.wrap import (
-    default_auto_wrap_policy,
-    enable_wrap,
-    wrap,
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
 )
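Why this hunk adds import functools alongside the switch from default_auto_wrap_policy to size_based_auto_wrap_policy: the size-based policy is normally bound to a parameter-count threshold with functools.partial before being passed to the FSDP constructor. The following is a minimal sketch of that usage, not part of the diff; the min_num_params=100 threshold and the shard() helper name are illustrative assumptions.

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Bind the size threshold: modules with at least min_num_params parameters
# are wrapped into their own FSDP unit (the value 100 is illustrative).
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=100
)

def shard(model: nn.Module) -> FSDP:
    # Assumes torch.distributed.init_process_group() has already been called
    # (see the setup helper described in section 1.3 below).
    return FSDP(model, auto_wrap_policy=my_auto_wrap_policy)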
1.3 Distributed training setup. As we mentioned, FSDP is a type of data parallelism that requires a distributed training environment, so here we use two helper functions to initialize the processes for distributed training and to clean them up afterwards.
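The two helper functions mentioned here are not shown in this excerpt of the diff. A minimal sketch of what they typically look like in FSDP_mnist.py follows; the localhost address, port 12355, and the nccl backend are assumptions for a single-node, one-GPU-per-rank run, not values taken from the diff.

import os
import torch.distributed as dist

def setup(rank: int, world_size: int) -> None:
    # Address and port of the rank-0 process; values assume a single-node run.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Initialize the default process group.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup() -> None:
    # Tear down the process group once training finishes.
    dist.destroy_process_group()

Each spawned worker would call setup(rank, world_size) before constructing and wrapping its model with FSDP, and cleanup() after training completes.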
@@ -196,7 +196,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.