diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index e8d7607fdb4c..6a31524cd2f3 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -52,6 +52,7 @@ from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module if is_wandb_available(): @@ -847,6 +848,11 @@ def main(args): logger.info("Initializing controlnet weights from unet") controlnet = ControlNetModel.from_unet(unet) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -908,9 +914,9 @@ def load_model_hook(models, input_dir): " doing mixed precision training, copy of the weights should still be float32." ) - if accelerator.unwrap_model(controlnet).dtype != torch.float32: + if unwrap_model(controlnet).dtype != torch.float32: raise ValueError( - f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, @@ -1158,7 +1164,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -1223,7 +1230,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - controlnet = accelerator.unwrap_model(controlnet) + controlnet = unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) if args.push_to_hub: