Commit e44b205

Make ControlNet SDXL Training Script torch.compile compatible (#6526)
* make torch.compile compatible
* fix quality
1 parent 60cb443 · commit e44b205

File tree

1 file changed: +11 −4 lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 11 additions & 4 deletions
@@ -52,6 +52,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
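For reference, `is_compiled_module` from `diffusers.utils.torch_utils` is a small predicate; a rough paraphrase of what it checks (not the verbatim implementation):

```python
import torch

def is_compiled_module_sketch(module) -> bool:
    # torch.compile (PyTorch 2.x) wraps an nn.Module in an OptimizedModule;
    # anything else is treated as uncompiled.
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```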
@@ -847,6 +848,11 @@ def main(args):
         logger.info("Initializing controlnet weights from unet")
         controlnet = ControlNetModel.from_unet(unet)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
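The new helper strips two layers of wrapping: whatever `accelerator.prepare` added (e.g. DDP), and on top of that the `torch.compile` wrapper, which keeps the original module on `_orig_mod`. A minimal single-process sketch of the same pattern, using a hypothetical toy module rather than the script's models:

```python
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()
net = accelerator.prepare(torch.compile(torch.nn.Linear(2, 2)))

def unwrap_model(model):
    model = accelerator.unwrap_model(model)  # strip accelerate's wrappers
    model = model._orig_mod if is_compiled_module(model) else model  # strip torch.compile's
    return model

print(type(unwrap_model(net)))  # <class 'torch.nn.modules.linear.Linear'>
```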
@@ -908,9 +914,9 @@ def load_model_hook(models, input_dir):
             " doing mixed precision training, copy of the weights should still be float32."
         )

-    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+    if unwrap_model(controlnet).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
         )

     # Enable TF32 for faster training on Ampere GPUs,
@@ -1158,7 +1164,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                     ],
                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
-                ).sample
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
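With `return_dict=False`, diffusers models return a plain tuple instead of a `ModelOutput` dataclass, and index `[0]` is the tensor previously read via `.sample`; the tuple form traces more cleanly under `torch.compile`. The equivalence, demonstrated on a deliberately tiny `UNet2DModel` (a hypothetical config chosen only so the sketch runs quickly, not the SDXL UNet from the script):

```python
import torch
from diffusers import UNet2DModel

# Tiny UNet: two plain down/up blocks, 8x8 input.
unet = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
).eval()
x, t = torch.randn(1, 3, 8, 8), torch.tensor([1])

with torch.no_grad():
    out_dataclass = unet(x, t)                 # ModelOutput with a .sample field
    out_tuple = unet(x, t, return_dict=False)  # plain tuple
assert torch.equal(out_dataclass.sample, out_tuple[0])
```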
@@ -1223,7 +1230,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)

         if args.push_to_hub:
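Saving from the fully unwrapped module matters because the `torch.compile` wrapper namespaces parameters under `_orig_mod.`, so a checkpoint written through the wrapper would not load back into a plain `ControlNetModel`. A quick illustration on a toy module (behavior as of PyTorch 2.x):

```python
import torch

net = torch.nn.Linear(2, 2)
compiled = torch.compile(net)

print(list(net.state_dict()))       # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']
```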
