Skip to content

Commit 0990377

Browse files
authored
Make T2I Adapter SDXL Training Script torch.compile compatible (#6577)
update for t2i_adapter
1 parent d6a70d8 commit 0990377

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

examples/t2i_adapter/train_t2i_adapter_sdxl.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from diffusers.optimization import get_scheduler
5151
from diffusers.utils import check_min_version, is_wandb_available
5252
from diffusers.utils.import_utils import is_xformers_available
53+
from diffusers.utils.torch_utils import is_compiled_module
5354

5455

5556
MAX_SEQ_LENGTH = 77
@@ -926,6 +927,11 @@ def load_model_hook(models, input_dir):
926927
else:
927928
raise ValueError("xformers is not available. Make sure it is installed correctly")
928929

930+
def unwrap_model(model):
931+
model = accelerator.unwrap_model(model)
932+
model = model._orig_mod if is_compiled_module(model) else model
933+
return model
934+
929935
if args.gradient_checkpointing:
930936
unet.enable_gradient_checkpointing()
931937

@@ -935,9 +941,9 @@ def load_model_hook(models, input_dir):
935941
" doing mixed precision training, copy of the weights should still be float32."
936942
)
937943

938-
if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
944+
if unwrap_model(t2iadapter).dtype != torch.float32:
939945
raise ValueError(
940-
f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
946+
f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
941947
)
942948

943949
# Enable TF32 for faster training on Ampere GPUs,
@@ -1198,7 +1204,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11981204
encoder_hidden_states=batch["prompt_ids"],
11991205
added_cond_kwargs=batch["unet_added_conditions"],
12001206
down_block_additional_residuals=down_block_additional_residuals,
1201-
).sample
1207+
return_dict=False,
1208+
)[0]
12021209

12031210
# Denoise the latents
12041211
denoised_latents = model_pred * (-sigmas) + noisy_latents
@@ -1279,7 +1286,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12791286
# Create the pipeline using using the trained modules and save it.
12801287
accelerator.wait_for_everyone()
12811288
if accelerator.is_main_process:
1282-
t2iadapter = accelerator.unwrap_model(t2iadapter)
1289+
t2iadapter = unwrap_model(t2iadapter)
12831290
t2iadapter.save_pretrained(args.output_dir)
12841291

12851292
if args.push_to_hub:

0 commit comments

Comments
 (0)