
Commit 813da2b

Merge branch 'main' into triton_kernel
2 parents: db65799 + 79d1723

3 files changed (+6, -14 lines)

beginner_source/basics/saveloadrun_tutorial.py

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@
 ########################
 # We can then load the model as demonstrated below.
 #
-# As described in `Saving and loading torch.nn.Modules <pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`__,
-# saving ``state_dict``s is considered the best practice. However,
+# As described in `Saving and loading torch.nn.Modules <https://pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`_,
+# saving ``state_dict`` is considered the best practice. However,
 # below we use ``weights_only=False`` because this involves loading the
 # model, which is a legacy use case for ``torch.save``.
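For context, the tutorial section being touched contrasts the two approaches: saving a ``state_dict`` (recommended) versus saving and loading the whole module, which needs ``weights_only=False``. A minimal sketch, assuming torchvision's ``vgg16`` as the example model (not shown in this diff):

import torch
import torchvision.models as models

# Recommended: persist only the weights (the state_dict) ...
model = models.vgg16(weights="IMAGENET1K_V1")
torch.save(model.state_dict(), "model_weights.pth")

# ... then re-create the architecture and load the weights safely.
model = models.vgg16()
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
model.eval()

# Legacy path discussed in the touched comment: pickling the whole module,
# which requires weights_only=False (and trust in the checkpoint file).
torch.save(model, "model.pth")
model = torch.load("model.pth", weights_only=False)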

beginner_source/introyt/modelsyt_tutorial.py

Lines changed: 1 addition & 3 deletions
@@ -311,9 +311,7 @@ def forward(self, sentence):
 # ``TransformerDecoder``) and subcomponents (``TransformerEncoderLayer``,
 # ``TransformerDecoderLayer``). For details, check out the
 # `documentation <https://pytorch.org/docs/stable/nn.html#transformer-layers>`__
-# on transformer classes, and the relevant
-# `tutorial <https://pytorch.org/tutorials/beginner/transformer_tutorial.html>`__
-# on pytorch.org.
+# on transformer classes.
 #
 # Other Layers and Functions
 # --------------------------
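Not part of this change, but a quick sketch of how the subcomponents named in the surrounding comment compose into a full encoder (hyperparameters here are illustrative, not taken from the tutorial):

import torch
import torch.nn as nn

# A single encoder layer: self-attention plus feed-forward.
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)

# Stack several identical layers into a TransformerEncoder.
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

src = torch.rand(32, 10, 512)   # (batch, sequence, embedding) since batch_first=True
out = encoder(src)              # output keeps the input shape: (32, 10, 512)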

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 3 additions & 9 deletions
@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch
     import torch.distributed as dist
     import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.stateful import Stateful
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn
@@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is:

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

     state_dict = { "app": AppState(model, optimizer)}
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-    # generates the state dict we will load into
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
     dcp.load(
         state_dict=state_dict,
         checkpoint_id=CHECKPOINT_DIR,
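Putting the three hunks together, the recipe's ``Stateful`` wrapper now reads roughly as below. This is a sketch of the surrounding recipe code, not part of the diff itself; it assumes ``model`` is the FSDP-wrapped module and ``optimizer`` its optimizer, set up earlier in the recipe.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CHECKPOINT_DIR = "checkpoint"

class AppState(Stateful):
    """Wraps model and optimizer so dcp.save/dcp.load call state_dict/load_state_dict for us."""

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # get_state_dict manages FSDP FQNs and selects the sharded state dict type
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        # push the loaded state back onto the wrapped model and optimizer
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )

# Saving: DCP calls AppState.state_dict() internally.
state_dict = {"app": AppState(model, optimizer)}
dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

# Loading: the same wrapper supplies (and receives) the state dict to load into,
# which is why the redundant manual get_state_dict block could be removed.
state_dict = {"app": AppState(model, optimizer)}
dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)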
