
Commit 7092296

FSDP2 example code for tutorial (#1343)
* FSDP2 example
* update README
* fix typo in README
* fix README
1 parent 54e132e commit 7092296

File tree

5 files changed: +492 −0 lines changed

distributed/FSDP2/README.md

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
## FSDP2
To run FSDP2 on a transformer model:
```
cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
```
* On the first run, it creates a `checkpoints` folder and saves state dicts there
* On subsequent runs, it loads from the latest checkpoint (the resulting on-disk layout is sketched below)
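For reference, a sketch of the checkpoint layout this produces (folder names are taken from `checkpoint.py`; the millisecond timestamp is illustrative):
```
checkpoints/
  dtensor_api/              # or dcp_api/ when --dcp-api is passed
    1700000000000/          # int(time.time() * 1000) at save time
      model_state_dict.pt
      optim_state_dict.pt
```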
To enable explicit prefetching:
```
torchrun --nproc_per_node 2 train.py --explicit-prefetch
```
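train.py itself is not part of this excerpt; the sketch below shows, under the assumption that each `TransformerBlock` has already been wrapped with `fully_shard`, what explicit forward prefetching typically looks like with the FSDP2 `FSDPModule` API:
```
from torch.distributed.fsdp import FSDPModule

from model import Transformer


def enable_forward_prefetch(model: Transformer) -> None:
    # Each block explicitly prefetches the parameter all-gather of the next
    # block instead of relying on FSDP2's implicit prefetching.
    layers = list(model.layers)
    for current, nxt in zip(layers[:-1], layers[1:]):
        assert isinstance(current, FSDPModule)  # only true after fully_shard(current)
        current.set_modules_to_forward_prefetch([nxt])
```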

To enable mixed precision:
```
torchrun --nproc_per_node 2 train.py --mixed-precision
```
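Again as a hedged sketch (the flag's exact wiring lives in train.py, which is not shown here), FSDP2 mixed precision is configured through a `MixedPrecisionPolicy` passed to `fully_shard`:
```
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

from model import Transformer


def shard_with_mixed_precision(model: Transformer) -> Transformer:
    # Compute in bf16, keep gradient reduction in fp32.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    return model
```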

To showcase the DCP API:
```
torchrun --nproc_per_node 2 train.py --dcp-api
```

## Ensure you are running a recent version of PyTorch:
See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.

distributed/FSDP2/checkpoint.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
import os
import time

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _init_optim_state,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import distribute_tensor, DTensor


MODEL_CHECKPOINT = "model_state_dict.pt"
OPTIM_CHECKPOINT = "optim_state_dict.pt"
PARAMS = "params"


def get_latest_checkpoint_folder(path):
    max_num = None
    if not os.path.exists(path):
        return max_num
    for name in os.listdir(path):
        folder_path = os.path.join(path, name)
        if os.path.isdir(folder_path):
            try:
                num = int(name)
                if max_num is None or num > max_num:
                    max_num = num
            except ValueError:
                pass  # Skip non-numeric folder names
    return max_num


# Writes and reads full (unsharded) checkpoints under
# {folder}/{dcp_api|dtensor_api}/{millisecond timestamp}/.
class Checkpointer:
    def __init__(self, folder: str, dcp_api: bool):
        self.folder = folder
        self.dcp_api = dcp_api
        self.last_training_time = get_latest_checkpoint_folder(
            f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
        )

    def is_empty(self):
        return self.last_training_time is None

    def load_model(self, model: FSDPModule):
        last_model_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_model_state_dict(
                model=model,
                model_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        # DTensor path: re-shard the full state dict onto this rank using each
        # parameter's device mesh and placements.
        meta_sharded_sd = model.state_dict()
        sharded_sd = {}
        for param_name, full_tensor in full_sd.items():
            sharded_meta_param = meta_sharded_sd.get(param_name)
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        model.load_state_dict(sharded_sd, strict=False, assign=True)

    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
        last_optim_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_optimizer_state_dict(
                model=model,
                optimizers=opt,
                optim_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        _init_optim_state(opt)
        param_groups = opt.state_dict()["param_groups"]
        state = opt.state_dict()["state"]

        full_param_groups = full_sd["param_groups"]
        full_state = full_sd["state"]

        for param_group, full_param_group in zip(param_groups, full_param_groups):
            for key, value in full_param_group.items():
                if key == PARAMS:
                    continue
                param_group[key] = value
            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
                if pid not in state:
                    continue
                param_state = state[pid]
                full_param_state = full_state[full_pid]
                for attr, full_tensor in full_param_state.items():
                    sharded_tensor = param_state[attr]
                    if isinstance(sharded_tensor, DTensor):
                        # exp_avg is DTensor
                        param_state[attr] = distribute_tensor(
                            full_tensor,
                            sharded_tensor.device_mesh,
                            sharded_tensor.placements,
                        )
                    else:
                        # step is plain tensor
                        param_state[attr] = full_tensor
        opt.load_state_dict(
            {
                "param_groups": param_groups,
                "state": state,
            }
        )

    def _get_full_model_state_dict(self, model: FSDPModule):
        if self.dcp_api:
            return get_model_state_dict(
                model=model,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )

        sharded_sd = model.state_dict()
        cpu_state_dict = {}
        for param_name, sharded_param in sharded_sd.items():
            # full_tensor() all-gathers the shards; only rank 0 keeps a CPU copy.
            full_param = sharded_param.full_tensor()
            if torch.distributed.get_rank() == 0:
                cpu_state_dict[param_name] = full_param.cpu()
            else:
                del full_param
        return cpu_state_dict

    def _get_full_optimizer_state_dict(
        self,
        model: FSDPModule,
        opt: torch.optim.Optimizer,
    ):
        if self.dcp_api:
            return get_optimizer_state_dict(
                model=model,
                optimizers=opt,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )
        is_rank_zero = torch.distributed.get_rank() == 0
        sharded_sd = opt.state_dict()
        sharded_state = sharded_sd["state"]
        full_state = {}
        for group_id, sharded_group in sharded_state.items():
            group_state = {}
            for attr, sharded_tensor in sharded_group.items():
                if isinstance(sharded_tensor, DTensor):
                    # "exp_avg" in AdamW is `DTensor`
                    full_tensor = sharded_tensor.full_tensor()
                else:
                    # "step" in AdamW is plain tensor
                    full_tensor = sharded_tensor
                if is_rank_zero:
                    group_state[attr] = full_tensor.cpu()
                else:
                    del full_tensor
            if is_rank_zero:
                full_state[group_id] = group_state
            else:
                del group_state
        if is_rank_zero:
            return {
                "param_groups": sharded_sd["param_groups"],
                "state": full_state,
            }
        else:
            return {}

    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
        model_state_dict = self._get_full_model_state_dict(model)
        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
        if torch.distributed.get_rank() == 0:
            new_training_time = int(time.time() * 1000)
            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
            os.makedirs(new_checkpoint_folder, exist_ok=True)
            torch.save(model_state_dict, new_model_checkpoint)
            torch.save(optim_state_dict, new_optim_checkpoint)
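
train.py drives this class but is not among the files shown in this excerpt; the following is a minimal, hypothetical sketch of how `Checkpointer` is meant to be used (`run_with_checkpointing` and `train_one_epoch` are assumed names, not the commit's code):
```
import torch
from torch.distributed.fsdp import FSDPModule

from checkpoint import Checkpointer


def run_with_checkpointing(model: FSDPModule, opt: torch.optim.Optimizer, train_one_epoch):
    # `model` must already be wrapped with FSDP2 (fully_shard) so its parameters are DTensors.
    checkpointer = Checkpointer("checkpoints", dcp_api=False)
    if not checkpointer.is_empty():
        checkpointer.load_model(model)
        checkpointer.load_optim(model, opt)
    train_one_epoch(model, opt)
    # Rank 0 writes checkpoints/dtensor_api/<timestamp-in-ms>/{model,optim}_state_dict.pt
    checkpointer.save(model, opt)
```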

distributed/FSDP2/model.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
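
The commit's train.py is not part of this excerpt; as a hedged sketch of how this toy Transformer is typically prepared for FSDP2 (assuming the process group has already been initialized, e.g. via torchrun), each `TransformerBlock` is wrapped with `fully_shard` before the root module:
```
import torch
from torch.distributed.fsdp import fully_shard

from model import ModelArgs, Transformer


def build_fsdp2_model(device: torch.device) -> Transformer:
    # Construct on the meta device so no full copy of the weights is allocated.
    with torch.device("meta"):
        model = Transformer(ModelArgs())
    # Shard each transformer block first, then the remaining root parameters.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    # Materialize the sharded parameters on the target device and initialize them.
    model.to_empty(device=device)
    for layer in model.layers:
        layer.reset_parameters()
    model.reset_parameters()
    return model
```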
