diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 0b7bd64e9091..122af23865b8 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
@@ -1429,7 +1428,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
-                    ).sample
+                        return_dict=False,
+                    )[0]
                 else:
                     unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1443,8 +1443,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
-                    ).sample
+                        noisy_model_input,
+                        timesteps,
+                        prompt_embeds_input,
+                        added_cond_kwargs=unet_added_conditions,
+                        return_dict=False,
+                    )[0]
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
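
For context, all three hunks apply the same pattern: calling Hugging Face modules with `return_dict=False` makes them return a plain tuple instead of a `ModelOutput`, so fields are read by position. In `encode_prompt`, `prompt_embeds[0]` is the pooled `text_embeds` and `prompt_embeds[-1][-2]` is the penultimate hidden state (equivalent to the old `prompt_embeds.hidden_states[-2]`); for the UNet, `[0]` replaces `.sample`. Below is a minimal sketch of the text-encoder side, assuming a standard CLIP checkpoint (`openai/clip-vit-base-patch32` is an illustrative choice here, not the SDXL encoder the script actually loads):

```python
# Sketch of the tuple-output convention the diff relies on.
# With return_dict=False, the model returns a plain tuple rather than a
# ModelOutput, so results are accessed by position instead of by name.
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a dog"], padding=True, return_tensors="pt")

with torch.no_grad():
    out = text_encoder(inputs.input_ids, output_hidden_states=True, return_dict=False)

pooled = out[0]                  # text_embeds: the projected, pooled output
hidden_states = out[-1]          # tuple of hidden states (requested via output_hidden_states=True)
penultimate = hidden_states[-2]  # same tensor as out.hidden_states[-2] with return_dict=True

print(pooled.shape, penultimate.shape)
```

With `output_hidden_states=True` and attentions disabled, the tuple is `(text_embeds, last_hidden_state, hidden_states)`, which is why `[-1][-2]` selects the penultimate hidden state. The UNet case follows the same convention: `UNet2DConditionModel` with `return_dict=False` returns `(sample,)`, so indexing `[0]` is a drop-in replacement for the `.sample` attribute.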