Commit babfbf3

Merge branch 'main' into fix-neural_style_tutorial_weight_init

2 parents c2afd4a + dc448c2

2 files changed: 9 additions, 0 deletions

intermediate_source/FSDP_tutorial.rst

Lines changed: 9 additions & 0 deletions
@@ -46,6 +46,15 @@ At a high level FSDP works as follow:
 * Run reduce_scatter to sync gradients
 * Discard parameters.

+One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
+
+.. figure:: /_static/img/distributed/fsdp_sharding.png
+    :width: 100%
+    :align: center
+    :alt: FSDP allreduce
+
+    FSDP Allreduce
+
 How to use FSDP
 --------------
 Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
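The paragraph added in this hunk describes FSDP's view of the DDP gradient all-reduce as a reduce-scatter followed by an all-gather. The snippet below is a hypothetical sketch, not part of this commit, that checks that equivalence directly with raw ``torch.distributed`` collectives; it assumes a recent PyTorch (with ``reduce_scatter_tensor``/``all_gather_into_tensor``), a single node launched via ``torchrun`` with one GPU per rank, and a world size that divides the tensor length.

.. code-block:: python

    # Hypothetical demo (not from this commit): an all-reduce of "gradients"
    # equals a reduce-scatter (each rank receives one shard of the sum)
    # followed by an all-gather (the shards are stitched back together).
    # Launch with e.g.: torchrun --nproc_per_node=2 demo.py
    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)  # assumes one GPU per rank on a single node

    # Fake per-rank gradient; its length must be divisible by the world size.
    grad = torch.arange(8, dtype=torch.float32, device="cuda") + rank

    # Path 1: DDP-style all-reduce -> every rank holds the full summed gradient.
    allreduced = grad.clone()
    dist.all_reduce(allreduced, op=dist.ReduceOp.SUM)

    # Path 2: FSDP-style decomposition.
    shard = torch.empty(8 // world, device="cuda")
    dist.reduce_scatter_tensor(shard, grad, op=dist.ReduceOp.SUM)
    # ... the optimizer step would update this parameter shard here ...
    gathered = torch.empty(8, device="cuda")
    dist.all_gather_into_tensor(gathered, shard)

    assert torch.allclose(allreduced, gathered)  # both paths agree
    dist.destroy_process_group()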

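The trailing context line points at the tutorial's MNIST walkthrough, which is not part of this diff. As rough orientation only, a minimal FSDP wrapping step of the kind that walkthrough builds toward could look like the sketch below; the model, optimizer, and hyperparameters are illustrative stand-ins, not the tutorial's code, and it again assumes a ``torchrun`` launch with one GPU per rank.

.. code-block:: python

    # Minimal, hypothetical FSDP sketch (illustrative only). A real training
    # loop would iterate over an MNIST DataLoader instead of a random batch.
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Toy MNIST-sized classifier standing in for the tutorial's model.
    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
    model = FSDP(model, device_id=rank)  # parameters are now sharded across ranks

    # Create the optimizer after wrapping so it references the sharded parameters.
    optim = torch.optim.Adadelta(model.parameters(), lr=1.0)

    data = torch.randn(32, 1, 28, 28, device="cuda")   # stand-in for an MNIST batch
    target = torch.randint(0, 10, (32,), device="cuda")

    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()  # gradients are reduce-scattered; each rank keeps only its shard
    optim.step()     # each rank updates the parameter shard it owns
    dist.destroy_process_group()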