
Fix Distributed Checkpoint Tutorial #3068


Merged · 5 commits · Oct 15, 2024
12 changes: 3 additions & 9 deletions recipes_source/distributed_checkpoint_recipe.rst
@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is:

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
@@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

     state_dict = { "app": AppState(model, optimizer)}
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-    # generates the state dict we will load into
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
     dcp.load(
         state_dict=state_dict,
         checkpoint_id=CHECKPOINT_DIR,
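
As a usage note, the save side pairs with this load in the same way; a hedged sketch, assuming the same ``AppState`` wrapper and ``CHECKPOINT_DIR`` (not part of this diff):

# illustrative save call matching the simplified load above
state_dict = {"app": AppState(model, optimizer)}
dcp.save(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)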