
Commit be0b425

[Training] make checkpointing compatible when using torch.compile (part II) (#6511)
make checkpointing compatible when using torch.compile.
1 parent da843b3 commit be0b425

File tree

1 file changed: +15 -9 lines changed


examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 15 additions & 9 deletions
@@ -56,6 +56,7 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
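For reference, `is_compiled_module` is a small predicate exported by `diffusers.utils.torch_utils`. A sketch of what it checks (the exact version guards in the library may differ):

import torch

def is_compiled_module(module) -> bool:
    # True when `module` is the wrapper object that torch.compile() returns.
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)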
@@ -1007,6 +1008,11 @@ def main(args):
                 if param.requires_grad:
                     param.data = param.to(torch.float32)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
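The helper is needed because `torch.compile()` does not return the module itself but a `torch._dynamo.eval_frame.OptimizedModule` wrapper that keeps the original module on `_orig_mod`. A minimal sketch of that behavior (`TinyNet` is illustrative, not from the script; requires PyTorch 2.x):

import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2

net = torch.compile(TinyNet())      # returns an OptimizedModule wrapper
print(type(net).__name__)           # OptimizedModule
print(type(net._orig_mod).__name__) # TinyNet, the original module

As a result, a plain `type(...)` comparison can resolve to the wrapper type rather than the underlying model class, and the `isinstance` dispatch in the hooks below would silently never match; stripping `_orig_mod` first restores the expected class.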
@@ -1017,13 +1023,13 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
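Just below this dispatch (unchanged by the commit), the script hands the collected state dicts to the SDXL pipeline's LoRA serializer, roughly:

StableDiffusionXLPipeline.save_lora_weights(
    output_dir,
    unet_lora_layers=unet_lora_layers_to_save,
    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
)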
@@ -1048,11 +1054,11 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
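For context, both hooks only take effect because the script registers them with Accelerate (also unchanged here), so that `accelerator.save_state(...)` and `accelerator.load_state(...)` route through them during checkpointing:

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)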
@@ -1621,16 +1627,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+            text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             )
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_two = unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_two.to(torch.float32))
             )
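Taken together, the change means checkpoint saving, checkpoint resuming, and the final LoRA export all compare models against the real underlying classes even when the script runs with `torch.compile`. A sketch of the mismatch being avoided (names assumed, `unet` as in the script):

compiled_unet = torch.compile(unet)
isinstance(compiled_unet, type(unet))            # False: wrapper, not the model class
isinstance(compiled_unet._orig_mod, type(unet))  # True once unwrapped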
