diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index 118dc7e7794..6a70bb02b0b 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -193,6 +193,7 @@ The reason that we need the ``state_dict`` prior to loading is:
         model = ToyModel().to(rank)
         model = FSDP(model)
 
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
         # generates the state dict we will load into
         model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
         state_dict = {