Commit 96cdbae

LTMeyers and svekars authored
Fix Distributed Checkpoint Tutorial (#3068)
* Fix undefined model and optimizer variables (#3067)
* Fix missing import of Stateful (#3067)
* Fix state_dict passed to dcp.load (#3067)

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 3eee4f2 commit 96cdbae
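
All three fixes touch the same pattern: the tutorial's AppState wrapper must import Stateful and read its own attributes instead of undefined module-level names. Below is a minimal sketch of the class as it stands after this commit; the load_state_dict side follows the tutorial's use of set_state_dict, and the surrounding FSDP setup is assumed.

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    class AppState(Stateful):
        """Wrapper so dcp.save/dcp.load can drive model and optimizer state together."""

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # get_state_dict manages FSDP FQNs and defaults to sharded state dicts;
            # self.model/self.optimizer fixes the undefined-variable bug (#3067).
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {"model": model_state_dict, "optim": optimizer_state_dict}

        def load_state_dict(self, state_dict):
            # restores the loaded state onto the wrapped model and optimizer in place
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"],
            )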

File tree

1 file changed: +3 -9 lines changed


recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 3 additions & 9 deletions
@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch
     import torch.distributed as dist
     import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.stateful import Stateful
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn
@@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is:

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

     state_dict = { "app": AppState(model, optimizer)}
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-    # generates the state dict we will load into
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
     dcp.load(
         state_dict=state_dict,
         checkpoint_id=CHECKPOINT_DIR,
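
With the duplicated hand-built state dicts removed, the load path passes a single {"app": AppState(...)} mapping to dcp.load, which drives AppState.load_state_dict to restore everything in place. A minimal single-process sketch, assuming the AppState class above and a CHECKPOINT_DIR previously written by a matching dcp.save; the tutorial itself runs this under mp.spawn with an FSDP-wrapped model:

    import torch
    import torch.nn as nn
    import torch.distributed.checkpoint as dcp

    CHECKPOINT_DIR = "checkpoint"  # assumed; must match the path given to dcp.save

    # simplified stand-in for the tutorial's FSDP-wrapped toy module
    model = nn.Linear(16, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # dcp.load calls AppState.load_state_dict with the deserialized contents,
    # so model and optimizer are restored in place; no hand-built state dicts needed
    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)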
