From a144ae870544337d0bf2aa94e9e0717e5bd1ae53 Mon Sep 17 00:00:00 2001 From: Vincent Date: Fri, 12 Jan 2024 23:25:56 +0700 Subject: [PATCH] make compile compatible --- .../train_text_to_image_lora_sdxl.py | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 606a88f55b32..90df453f04e7 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -54,6 +54,7 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -640,6 +640,11 @@ def main(args): if param.requires_grad: param.data = param.to(torch.float32) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -650,13 +655,13 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -681,11 +686,11 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1034,8 +1039,12 @@ def compute_time_ids(original_size, crops_coords_top_left): ) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions - ).sample + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -1128,9 +1137,9 @@ def compute_time_ids(original_size, crops_coords_top_left): pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - unet=accelerator.unwrap_model(unet), + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + unet=unwrap_model(unet), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -1169,12 +1178,12 @@ def compute_time_ids(original_size, crops_coords_top_left): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_two = unwrap_model(text_encoder_two) text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one)) text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))