
Commit 8f2850f

Merge branch 'main' into patch-1
2 parents 9552006 + 96cdbae commit 8f2850f

2 files changed: +4, -12 lines

beginner_source/introyt/modelsyt_tutorial.py

Lines changed: 1 addition & 3 deletions
@@ -311,9 +311,7 @@ def forward(self, sentence):
 # ``TransformerDecoder``) and subcomponents (``TransformerEncoderLayer``,
 # ``TransformerDecoderLayer``). For details, check out the
 # `documentation <https://pytorch.org/docs/stable/nn.html#transformer-layers>`__
-# on transformer classes, and the relevant
-# `tutorial <https://pytorch.org/tutorials/beginner/transformer_tutorial.html>`__
-# on pytorch.org.
+# on transformer classes.
 #
 # Other Layers and Functions
 # --------------------------
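For orientation only (this is not part of the change), a minimal sketch of how the encoder classes named in that comment compose; the sizes below are arbitrary example values:

import torch
import torch.nn as nn

# one self-attention + feed-forward block; d_model/nhead are example values
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
# stack six identical layers into a full encoder
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

src = torch.rand(10, 32, 512)   # (sequence length, batch, embedding dim)
out = encoder(src)              # output has the same shape as the input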

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 3 additions & 9 deletions
@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch
     import torch.distributed as dist
     import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.stateful import Stateful
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn
@@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is:
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 
         state_dict = { "app": AppState(model, optimizer)}
-        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-        # generates the state dict we will load into
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-        state_dict = {
-            "model": model_state_dict,
-            "optimizer": optimizer_state_dict
-        }
         dcp.load(
             state_dict=state_dict,
             checkpoint_id=CHECKPOINT_DIR,
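
Taken together, these hunks make the ``AppState`` wrapper read ``self.model`` / ``self.optimizer``, import ``Stateful``, and drop the manually built state dict that preceded ``dcp.load``. A minimal standalone sketch of the resulting pattern, assuming the FSDP-wrapped ``model``, ``optimizer``, and ``CHECKPOINT_DIR`` defined elsewhere in the recipe:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppState(Stateful):
    """Wraps the model and optimizer so DCP drives save/load through this object."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # get_state_dict manages FSDP FQNs and defaults to sharded state dicts;
        # note self.model / self.optimizer, matching the fix in the hunks above
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        # push the loaded shards back onto the wrapped model and optimizer
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


# Usage in the load path (model, optimizer, and CHECKPOINT_DIR come from the
# surrounding example in the recipe):
state_dict = {"app": AppState(model, optimizer)}
dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)

Because ``AppState`` satisfies the ``Stateful`` protocol, ``dcp.load`` restores state in place through ``load_state_dict``, which is why the manually constructed ``state_dict`` removed in the last hunk is no longer needed.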
