
updated FSDP tutorial #2975


Merged 4 commits on Jul 29, 2024
2 changes: 1 addition & 1 deletion intermediate_source/FSDP_adavnced_tutorial.rst
@@ -502,7 +502,7 @@ layer class (holding MHSA and FFN).


model = FSDP(model,
-        fsdp_auto_wrap_policy=t5_auto_wrap_policy)
+        auto_wrap_policy=t5_auto_wrap_policy)

To see the wrapped model, you can easily print the model and visually inspect
the sharding and FSDP units as well.
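
For context, the ``t5_auto_wrap_policy`` passed above is not defined in this hunk; in the advanced tutorial it is typically built from ``transformer_auto_wrap_policy``. A minimal sketch, assuming the HuggingFace ``T5Block`` layer class (an assumption, not part of this diff):

.. code-block:: python

   import functools

   from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
   from transformers.models.t5.modeling_t5 import T5Block  # assumed layer class

   # Wrap each T5 transformer block (holding MHSA and FFN) as its own FSDP unit.
   t5_auto_wrap_policy = functools.partial(
       transformer_auto_wrap_policy,
       transformer_layer_cls={T5Block},
   )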
14 changes: 7 additions & 7 deletions intermediate_source/FSDP_tutorial.rst
@@ -70,7 +70,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
1.2 Import necessary packages

.. note::
-  This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`.
+  This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy` and `fsdp_auto_wrap_policy` with `auto_wrap_policy`.

.. code-block:: python

@@ -308,7 +308,7 @@ We have recorded cuda events to measure the time of FSDP model specifics. The CU
CUDA event elapsed time on training loop 40.67462890625sec

Wrapping the model with FSDP, the model will look as follows, we can see the model has been wrapped in one FSDP unit.
-Alternatively, we will look at adding the fsdp_auto_wrap_policy next and will discuss the differences.
+Alternatively, we will look at adding the auto_wrap_policy next and will discuss the differences.

.. code-block:: bash

@@ -335,12 +335,12 @@ The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarg

FSDP Peak Memory Usage

-Applying *fsdp_auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
+Applying *auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model.
In that case, the allgather would collect the full parameters for all 100 linear layers, and hence won't save CUDA memory for parameter sharding.
Also, there is only one blocking allgather call for the all 100 linear layers, there will not be communication and computation overlapping between layers.

-To avoid that, you can pass in an fsdp_auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., size limit).
+To avoid that, you can pass in an auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., size limit).
In that way you will have multiple FSDP units, and only one FSDP unit needs to collect full parameters at a time. E.g., suppose you have 5 FSDP units, and each wraps 20 linear layers.
Then, in the forward, the 1st FSDP unit will allgather parameters for the first 20 linear layers, do computation, discard the parameters and then move on to the next 20 linear layers. So, at any point in time, each rank only materializes parameters/grads for 20 linear layers instead of 100.
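
As a concrete illustration of the size-based wrapping described above, here is a hedged sketch of how such a policy is commonly constructed and passed in. The ``min_num_params`` threshold is illustrative, not taken from this PR, and a process group must already be initialized before wrapping:

.. code-block:: python

   import functools

   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
   from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

   # Seal the current FSDP unit and start a new one once the accumulated
   # parameter count of the modules being wrapped exceeds min_num_params.
   my_auto_wrap_policy = functools.partial(
       size_based_auto_wrap_policy, min_num_params=20000  # illustrative value
   )

   # `model` is assumed to be an nn.Module already placed on this rank.
   model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)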

@@ -358,9 +358,9 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
model = Net().to(rank)

model = FSDP(model,
-        fsdp_auto_wrap_policy=my_auto_wrap_policy)
+        auto_wrap_policy=my_auto_wrap_policy)

-Applying the fsdp_auto_wrap_policy, the model would be as follows:
+Applying the auto_wrap_policy, the model would be as follows:

.. code-block:: bash

@@ -411,7 +411,7 @@ In 2.4 we just add it to the FSDP wrapper
.. code-block:: python

model = FSDP(model,
-        fsdp_auto_wrap_policy=my_auto_wrap_policy,
+        auto_wrap_policy=my_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True))
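
For completeness, ``CPUOffload`` comes from ``torch.distributed.fsdp``; a hedged sketch of just this piece, assuming ``model`` and ``my_auto_wrap_policy`` are defined as in the earlier tutorial snippets:

.. code-block:: python

   from torch.distributed.fsdp import CPUOffload
   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

   # Offload sharded parameters to CPU when they are not in use,
   # trading host/device copies for lower GPU memory pressure.
   model = FSDP(
       model,
       auto_wrap_policy=my_auto_wrap_policy,
       cpu_offload=CPUOffload(offload_params=True),
   )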

