From d490aef00b8ec36153b0d2145ae97517257a34c8 Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Mon, 30 Sep 2024 16:50:52 +0200
Subject: [PATCH 1/3] Fix undefined model and optimizer variables (#3067)

---
 recipes_source/distributed_checkpoint_recipe.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index 8f93c2222d6..4956c873dbe 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -82,7 +82,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -202,7 +202,7 @@ The reason that we need the ``state_dict`` prior to loading is:
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict

From 0e52e075d53494c325e0d8cf7175537d7ad02e6d Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Mon, 30 Sep 2024 16:51:45 +0200
Subject: [PATCH 2/3] Fix missing import of Stateful (#3067)

---
 recipes_source/distributed_checkpoint_recipe.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index 4956c873dbe..eca9e68f1b3 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -178,6 +178,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch
     import torch.distributed as dist
     import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.stateful import Stateful
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn

From 1f1aad64760fa751a056621b005aabe9a9c37d94 Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Mon, 30 Sep 2024 16:52:12 +0200
Subject: [PATCH 3/3] Fix state_dict passed to dcp.load (#3067)

---
 recipes_source/distributed_checkpoint_recipe.rst | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index eca9e68f1b3..374b5af2b7b 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -253,13 +253,6 @@ The reason that we need the ``state_dict`` prior to loading is:
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 
         state_dict = { "app": AppState(model, optimizer)}
-        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-        # generates the state dict we will load into
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
-        state_dict = {
-            "model": model_state_dict,
-            "optimizer": optimizer_state_dict
-        }
         dcp.load(
             state_dict=state_dict,
             checkpoint_id=CHECKPOINT_DIR,
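
Note: taken together, the three patches above leave the recipe's checkpoint load path in the shape sketched below. This is a minimal single-process sketch of the corrected pattern, not the recipe itself: the plain nn.Linear stands in for the recipe's FSDP-wrapped toy module, CHECKPOINT_DIR is a placeholder path, and it assumes a PyTorch version (2.2+) where DCP can save/load without an initialized process group.

    import torch
    import torch.nn as nn
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    CHECKPOINT_DIR = "checkpoint"  # placeholder path, not the recipe's value

    class AppState(Stateful):
        """DCP calls state_dict()/load_state_dict() on this wrapper inside dcp.save/dcp.load."""

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # PATCH 1/3: read self.model/self.optimizer rather than undefined globals;
            # get_state_dict also manages FSDP FQNs and the sharded state dict type
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {"model": model_state_dict, "optim": optimizer_state_dict}

        def load_state_dict(self, state_dict):
            # mirror of state_dict(): install the loaded values on the wrapped objects
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"],
            )

    model = nn.Linear(8, 8)  # stand-in for the recipe's FSDP-wrapped toy module
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    dcp.save({"app": AppState(model, optimizer)}, checkpoint_id=CHECKPOINT_DIR)

    # PATCH 3/3: hand dcp.load the Stateful wrapper itself; it is populated in
    # place, so no separate get_state_dict() call is needed before loading
    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)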