Make Dreambooth SD LoRA Training Script torch.compile compatible #6534

Merged
33 changes: 22 additions & 11 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -56,6 +56,7 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
return_dict=False,
)
prompt_embeds = prompt_embeds[0]

@@ -843,6 +845,11 @@ def main(args):
)
text_encoder.add_adapter(text_lora_config)

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -852,9 +859,9 @@ def save_model_hook(models, weights, output_dir):
text_encoder_lora_layers_to_save = None

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
elif isinstance(model, type(unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
@@ -877,9 +884,9 @@ def load_model_hook(models, input_dir):
while len(models) > 0:
model = models.pop()

if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
elif isinstance(model, type(unwrap_model(text_encoder))):
text_encoder_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1118,7 +1125,7 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
if unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
@@ -1128,8 +1135,12 @@ def compute_text_embeddings(prompt):

# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
noisy_model_input,
timesteps,
encoder_hidden_states,
class_labels=class_labels,
return_dict=False,
)[0]

# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine tuned
@@ -1215,8 +1226,8 @@ def compute_text_embeddings(prompt):
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
unet=unwrap_model(unet),
text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -1284,13 +1295,13 @@ def compute_text_embeddings(prompt):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
unet = unet.to(torch.float32)

unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = unwrap_model(text_encoder)
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
else:
text_encoder_state_dict = None
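
For context, below is a minimal, self-contained sketch of the pattern this PR introduces, assuming PyTorch >= 2.0: torch.compile wraps an nn.Module in a torch._dynamo OptimizedModule that keeps the original module on its _orig_mod attribute, so the new unwrap_model helper peels that wrapper off before the isinstance checks in the save/load hooks and before handing the model to DiffusionPipeline (against the wrapper, both isinstance branches would resolve to OptimizedModule and misfire). The return_dict=False calls likewise make the text encoder and UNet return plain tuples, which trace more cleanly under torch.compile than ModelOutput objects. The standalone is_compiled_module here is modeled on the diffusers utility for illustration, not copied from the library source.

import torch
import torch.nn as nn


def is_compiled_module(module: nn.Module) -> bool:
    # torch.compile (PyTorch 2.x) returns a torch._dynamo OptimizedModule wrapper.
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


def unwrap_model(model: nn.Module) -> nn.Module:
    # In the training script this runs after accelerator.unwrap_model(model);
    # only the torch.compile unwrapping is shown here.
    return model._orig_mod if is_compiled_module(model) else model


if __name__ == "__main__":
    base = nn.Linear(4, 4)
    compiled = torch.compile(base)  # wrapping is eager; compilation happens on first forward
    assert unwrap_model(compiled) is base
    assert isinstance(unwrap_model(compiled), nn.Linear)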