2 files changed: +21 −2 lines changed
## FSDP2
To run FSDP2 on a transformer model:
```
+cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
```
+* On the first run, it creates a "checkpoints" folder and saves the state dicts there
+* On the second run, it loads from the previous checkpoints
+
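The save-on-first-run / load-on-second-run behavior above is handled by the example's `Checkpointer` helper. As a rough single-process sketch of that pattern (the `save_or_load` function and `CHECKPOINT_DIR` constant here are illustrative, not the example's actual API, and the real example coordinates sharded state dicts across ranks):

```python
import os

import torch
import torch.nn as nn

CHECKPOINT_DIR = "checkpoints"  # folder created on the first run

def save_or_load(model: nn.Module) -> bool:
    """Load model state if a checkpoint exists, otherwise save one.

    Returns True when a previous checkpoint was loaded. Illustrative
    stand-in for the example's Checkpointer helper, not its real API.
    """
    path = os.path.join(CHECKPOINT_DIR, "model_state_dict.pt")
    if os.path.exists(path):
        # Second run: restore the previously saved state dict.
        model.load_state_dict(torch.load(path))
        return True
    # First run: create the folder and save the current state dict.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(model.state_dict(), path)
    return False
```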
+To enable explicit prefetching:
+```
+torchrun --nproc_per_node 2 train.py --explicit-prefetch
+```
+
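With explicit prefetching, each `fully_shard`-wrapped module registers which later modules' all-gathers to kick off early via FSDP2's `set_modules_to_forward_prefetch` method. The example's `train.py` defines its own helper; a plausible sketch of the idea (assuming a `model.layers` sequence of sharded transformer blocks) is:

```python
def set_modules_to_forward_prefetch(model, num_to_forward_prefetch: int) -> None:
    """Have each layer explicitly prefetch the next N layers' parameters.

    Sketch only: assumes `model.layers` is a sliceable sequence of
    fully_shard-wrapped modules, each exposing FSDP2's
    set_modules_to_forward_prefetch method.
    """
    for i, layer in enumerate(model.layers):
        # The last layers have fewer (or zero) successors left to prefetch.
        targets = list(model.layers[i + 1 : i + 1 + num_to_forward_prefetch])
        if targets:
            layer.set_modules_to_forward_prefetch(targets)
```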
+To enable mixed precision:
+```
+torchrun --nproc_per_node 2 train.py --mixed-precision
+```
+
+To showcase the DCP API:
+```
+torchrun --nproc_per_node 2 train.py --dcp-api
+```
## Ensure you are running a recent version of PyTorch
See https://pytorch.org/get-started/locally/ to install at least PyTorch 2.5, and ideally a current nightly build.
```diff
@@ -57,8 +57,6 @@ def main(args):
     fully_shard(model, **fsdp_kwargs)

     inspect_model(model)
-    if args.mixed_precision:
-        inspect_mixed_precision(model)

     if args.explicit_prefetching:
         set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
@@ -70,6 +68,9 @@ def main(args):
         model.reset_parameters()
     else:
         checkpointer.load_model(model)
+
+    if args.mixed_precision:
+        inspect_mixed_precision(model)

     optim = torch.optim.Adam(model.parameters(), lr=1e-2)
     if checkpointer.last_training_time is not None:
```