@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
 
     def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-       model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+       model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch
     import torch.distributed as dist
     import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.stateful import Stateful
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn
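(The newly imported `Stateful` is the protocol that `AppState` subclasses. Once the checkpointed object implements it, saving reduces to handing the wrapper to `dcp.save`; a short sketch, reusing the tutorial's `CHECKPOINT_DIR`:)

```python
# wrap model/optimizer once; DCP calls AppState.state_dict() during the save
state_dict = {"app": AppState(model, optimizer)}
dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
```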
@@ -202,7 +203,7 @@ The reason that we need the ``state_dict`` prior to loading is:
 
     def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-       model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+       model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
@@ -252,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 
     state_dict = {"app": AppState(model, optimizer)}
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-    # generates the state dict we will load into
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-    state_dict = {
-        "model": model_state_dict,
-        "optimizer": optimizer_state_dict
-    }
     dcp.load(
         state_dict=state_dict,
         checkpoint_id=CHECKPOINT_DIR,
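(Why the deleted block is safe to drop: for `Stateful` entries, `dcp.load` builds the target state dict itself, so pre-generating one with `get_state_dict` is redundant. A conceptual sketch of that flow, not DCP's actual internals:)

```python
# roughly what dcp.load does per Stateful entry (conceptual, simplified):
for key, obj in state_dict.items():
    if isinstance(obj, Stateful):
        target = obj.state_dict()    # allocate tensors to load into
        # ... checkpoint data is read into `target` in place ...
        obj.load_state_dict(target)  # push loaded values back onto model/optim
```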