 # In this tutorial, we train a ``nn.TransformerEncoder`` model on a
-# language modeling task. Please note that this tutorial does not cover
+# causal language modeling task. Please note that this tutorial does not cover
 # the training of `nn.TransformerDecoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html#torch.nn.TransformerDecoder>`__, as depicted in
 # the right half of the diagram above. The language modeling task is to assign a
 # probability for the likelihood of a given word (or a sequence of words)
@@ -41,8 +41,10 @@
 # Along with the input sequence, a square attention mask is required because the
 # self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
 # the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
+# tokens on the future positions should be masked. This masking, combined with the fact that
+# the output embeddings are offset by one position, ensures that the
+# predictions for position i can depend only on the known outputs at positions less than i.
+# To produce a probability distribution over output words, the output of the ``nn.TransformerEncoder``
 # model is passed through a linear layer to output unnormalized logits.
 # The log-softmax function isn't applied here due to the later use of
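A minimal, self-contained sketch of what this hunk describes is given below. It is not part of the patch; the tiny dimensions and the ``lm_head`` name are illustrative only. It builds the square causal mask (equivalent to ``torch.nn.Transformer.generate_square_subsequent_mask``), runs a small ``nn.TransformerEncoder`` with it, and leaves the logits unnormalized so that ``nn.CrossEntropyLoss`` can apply the log-softmax internally.

```python
import torch
import torch.nn as nn

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # -inf above the diagonal, 0 elsewhere: position i may attend only to positions <= i.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

seq_len, batch_size, d_model, ntokens = 8, 2, 32, 100

encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
lm_head = nn.Linear(d_model, ntokens)             # linear layer producing unnormalized logits

src = torch.randn(seq_len, batch_size, d_model)   # stand-in for already-embedded input tokens
mask = generate_square_subsequent_mask(seq_len)
hidden = encoder(src, mask=mask)                  # (seq_len, batch, d_model)
logits = lm_head(hidden)                          # (seq_len, batch, ntokens)

# No explicit log-softmax: nn.CrossEntropyLoss applies it internally.
targets = torch.randint(0, ntokens, (seq_len, batch_size))
loss = nn.CrossEntropyLoss()(logits.view(-1, ntokens), targets.view(-1))
print(loss.item())
```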
intermediate_source/FSDP_tutorial.rst (9 additions, 0 deletions)
@@ -46,6 +46,15 @@ At a high level FSDP works as follows:
 * Run reduce_scatter to sync gradients
 * Discard parameters.

+One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
+
 Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
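The decomposition described in the added paragraph can be sketched with raw collectives. The snippet below is conceptual rather than FSDP's actual implementation; the ``sharded_sgd_step`` name and the flat tensors are hypothetical, it assumes an already-initialized default process group (e.g. via ``dist.init_process_group``), a flattened parameter and gradient whose length is divisible by the world size, and plain SGD standing in for the real optimizer.

```python
import torch
import torch.distributed as dist

def sharded_sgd_step(flat_param: torch.Tensor, flat_grad: torch.Tensor, lr: float = 0.01) -> None:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Assumes flat_param.numel() is divisible by world_size so all chunks are equal-sized.

    # 1) Backward: reduce-scatter the gradient; each rank keeps one reduced shard.
    grad_chunks = list(flat_grad.chunk(world_size))
    grad_shard = torch.empty_like(grad_chunks[rank])
    dist.reduce_scatter(grad_shard, grad_chunks, op=dist.ReduceOp.SUM)
    grad_shard /= world_size  # average, matching DDP's all-reduce semantics

    # 2) Optimizer step: update only the locally owned parameter shard.
    param_shard = flat_param.chunk(world_size)[rank].clone()
    param_shard -= lr * grad_shard

    # 3) Before the next forward: all-gather updated shards back into the full parameter.
    gathered = [torch.empty_like(param_shard) for _ in range(world_size)]
    dist.all_gather(gathered, param_shard)
    flat_param.copy_(torch.cat(gathered))
```

FSDP performs these steps internally for the modules it wraps; the sketch only mirrors the reduce-scatter / all-gather view given above.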