
Commit 98dd2c2 (1 parent: 7815968)

update README

2 files changed, 21 insertions(+), 2 deletions(-)


distributed/FSDP2/README.md (18 additions, 0 deletions)
@@ -1,8 +1,26 @@
 ## FSDP2
 To run FSDP2 on transformer model:
 ```
+cd distributed/FSDP2
 torchrun --nproc_per_node 2 train.py
 ```
+* The 1st time, it creates a "checkpoints" folder and saves state dicts there
+* The 2nd time, it loads from the previous checkpoints
+
+To enable explicit prefetching
+```
+torchrun --nproc_per_node 2 train.py --explicit-prefetch
+```
+
+To enable mixed precision
+```
+torchrun --nproc_per_node 2 train.py --mixed-precision
+```
+
+To showcase DCP API
+```
+torchrun --nproc_per_node 2 train.py --dcp-api
+```
 
 ## Ensure you are running a recent version of PyTorch:
 see https://pytorch.org/get-started/locally/ to install at least 2.5 and ideally a current nightly build.
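For context on what `--mixed-precision` toggles: in FSDP2, mixed precision is configured per `fully_shard` call via a `MixedPrecisionPolicy` rather than a global autocast. Below is a minimal sketch, assuming a recent PyTorch that exports `fully_shard` and `MixedPrecisionPolicy` from `torch.distributed.fsdp` (older builds keep them under `torch.distributed._composable.fsdp`), and using a toy `nn.Sequential` as a stand-in for the repo's Transformer:

```
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

dist.init_process_group("nccl")  # run under torchrun --nproc_per_node 2
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()

# All-gather and compute with parameters in bf16, but reduce-scatter
# gradients in fp32 so the optimizer step stays numerically stable.
policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model:
    fully_shard(layer, mp_policy=policy)  # shard each block first
fully_shard(model, mp_policy=policy)      # then the root module

out = model(torch.randn(8, 1024, device="cuda"))
assert out.dtype == torch.bfloat16  # forward ran in the low-precision dtype
dist.destroy_process_group()
```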

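Likewise, `--dcp-api` exercises `torch.distributed.checkpoint` (DCP), which saves and loads sharded state dicts collectively. A sketch of the round trip, reusing the sharded `model` from the snippet above; the `checkpoints/demo` path is a placeholder, and the repo's own checkpointer class may wrap this differently:

```
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

# Save: each rank contributes only its own shards to the checkpoint.
state_dict = get_model_state_dict(model)
dcp.save(state_dict, checkpoint_id="checkpoints/demo")

# Load: read shards back into the state dict, then install them on the module.
state_dict = get_model_state_dict(model)
dcp.load(state_dict, checkpoint_id="checkpoints/demo")
set_model_state_dict(model, state_dict)
```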
distributed/FSDP2/train.py (3 additions, 2 deletions)
@@ -57,8 +57,6 @@ def main(args):
     fully_shard(model, **fsdp_kwargs)
 
     inspect_model(model)
-    if args.mixed_precision:
-        inspect_mixed_precision(model)
 
     if args.explicit_prefetching:
         set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
@@ -70,6 +68,9 @@ def main(args):
         model.reset_parameters()
     else:
         checkpointer.load_model(model)
+
+    if args.mixed_precision:
+        inspect_mixed_precision(model)
 
     optim = torch.optim.Adam(model.parameters(), lr=1e-2)
     if checkpointer.last_training_time is not None:
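Note the second hunk: `inspect_mixed_precision` now runs after the checkpoint branch, so it inspects the parameters the run will actually train (freshly reset or loaded from disk) rather than the pre-load state. For the `set_modules_to_forward_prefetch` call visible in the first hunk's context, a plausible sketch of such a helper follows, assuming the transformer blocks are exposed as `model.layers` and each was wrapped with `fully_shard` (so each is an FSDPModule); the actual helper lives in the repo's utilities and may differ:

```
def set_modules_to_forward_prefetch(model, num_to_forward_prefetch: int) -> None:
    """Ask layer i to pre-all-gather the parameters of the next k layers
    during forward, overlapping communication with computation."""
    layers = list(model.layers)
    for i, layer in enumerate(layers):
        if i >= len(layers) - num_to_forward_prefetch:
            break  # the tail layers have nothing left to prefetch
        layer.set_modules_to_forward_prefetch(
            layers[i + 1 : i + 1 + num_to_forward_prefetch]
        )
```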
