From eb9dfc65bdc3f5a747a6dde77210e4136838ccbc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Jan 2024 10:50:44 +0530 Subject: [PATCH 1/3] make it torch.compile compatible --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0b7bd64e9091..dd2d10f7bd1c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1429,7 +1429,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - ).sample + return_dict=False + )[0] else: unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( @@ -1443,8 +1444,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions - ).sample + noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, return_dict=False + )[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": From e035a0dd71d8f9a99d98537f53766606ae06a8d4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Jan 2024 11:04:28 +0530 Subject: [PATCH 2/3] make the text encoder compatible too. 
--- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index dd2d10f7bd1c..b2be02edc813 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -781,12 +781,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), - output_hidden_states=True, + output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) From ddcd43da9aa54740762ca27c68163ececf3068c2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Jan 2024 11:27:26 +0530 Subject: [PATCH 3/3] style --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index b2be02edc813..122af23865b8 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -780,8 +780,7 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, return_dict=False + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder @@ -1429,7 +1428,7 @@ def 
compute_text_embeddings(prompt, text_encoders, tokenizers): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - return_dict=False + return_dict=False, )[0] else: unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} @@ -1444,7 +1443,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, return_dict=False + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, )[0] # Get the target for loss depending on the prediction type