Commit 60cb443

Make Dreambooth SD LoRA Training Script torch.compile compatible (#6534)
support compile
1 parent 1dd0ac9 commit 60cb443

File tree

1 file changed (+22 -11)


examples/dreambooth/train_dreambooth_lora.py

Lines changed: 22 additions & 11 deletions
@@ -56,6 +56,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]
 
@@ -843,6 +845,11 @@ def main(args):
         )
         text_encoder.add_adapter(text_lora_config)
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
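
For context on the hunk above: torch.compile wraps a module in an OptimizedModule whose original module is kept on _orig_mod, so without the new unwrap_model helper the isinstance checks and state-dict exports below would see the wrapper type rather than the UNet or text encoder. A minimal standalone sketch of that behavior (PyTorch 2.x; not part of the training script):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)  # returns an OptimizedModule wrapper

# The wrapper is a different type, so isinstance(model, type(net))-style
# checks fail unless the original module is recovered first.
print(type(compiled).__name__)    # OptimizedModule
print(compiled._orig_mod is net)  # True

# Same idea as the helper added above, minus the Accelerate unwrapping.
def unwrap(model):
    return model._orig_mod if hasattr(model, "_orig_mod") else model

assert unwrap(compiled) is net
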
@@ -852,9 +859,9 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                elif isinstance(model, type(unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -877,9 +884,9 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            elif isinstance(model, type(unwrap_model(text_encoder))):
                 text_encoder_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1118,7 +1125,7 @@ def compute_text_embeddings(prompt):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )
 
-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
@@ -1128,8 +1135,12 @@ def compute_text_embeddings(prompt):
 
                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    encoder_hidden_states,
+                    class_labels=class_labels,
+                    return_dict=False,
+                )[0]
 
                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
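
The return_dict=False / [0] changes in this hunk (and in encode_prompt above) swap the ModelOutput-style returns for plain tuples; avoiding graph breaks on the output dataclass under torch.compile is the likely motivation, though the commit message does not spell it out. Both call styles yield the same tensor; a small sketch with a toy UNet config (illustrative sizes only, not the script's pretrained model):

import torch
from diffusers import UNet2DConditionModel

# Tiny randomly initialized UNet so this runs quickly.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
).eval()

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    # Default: a UNet2DConditionOutput dataclass with a .sample attribute.
    pred_a = unet(sample, timestep, encoder_hidden_states).sample
    # With return_dict=False: a plain tuple, indexed with [0] as in the diff.
    pred_b = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]

print(torch.allclose(pred_a, pred_b))  # True
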
@@ -1215,8 +1226,8 @@ def compute_text_embeddings(prompt):
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
+                    unet=unwrap_model(unet),
+                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -1284,13 +1295,13 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
-            text_encoder = accelerator.unwrap_model(text_encoder)
+            text_encoder = unwrap_model(text_encoder)
             text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
         else:
             text_encoder_state_dict = None
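
Note that none of the hunks call torch.compile themselves; the commit only makes the script safe to run when the UNet or text encoder arrive already compiled, for example via Accelerate's TorchDynamo support (accelerate launch --dynamo_backend inductor ...). A runnable end-to-end sketch of the save-path pattern, with nn.Linear standing in for the models (PyTorch 2.x and accelerate required):

import torch
import torch.nn as nn
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()

model = torch.compile(nn.Linear(8, 8))  # stand-in for the compiled UNet/text encoder
model = accelerator.prepare(model)

def unwrap_model(m):
    # Same helper as the one added to train_dreambooth_lora.py.
    m = accelerator.unwrap_model(m)
    return m._orig_mod if is_compiled_module(m) else m

plain = unwrap_model(model)
print(type(plain).__name__)  # Linear
# The unwrapped state dict has clean keys (no "_orig_mod." prefix), which is
# what get_peft_model_state_dict and the LoRA saving hooks expect.
print(all(not k.startswith("_orig_mod.") for k in plain.state_dict()))  # True
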
