From 8265cbad00bd4f9557169ebae871089cf733ab5e Mon Sep 17 00:00:00 2001 From: charchit7 Date: Thu, 11 Jan 2024 13:36:23 +0530 Subject: [PATCH 1/2] make torch.compile compatible --- examples/controlnet/train_controlnet_sdxl.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index e8d7607fdb4c..7d3ea3eb0d87 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -52,7 +52,7 @@ from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.import_utils import is_xformers_available - +from diffusers.utils.torch_utils import is_compiled_module if is_wandb_available(): import wandb @@ -847,6 +847,11 @@ def main(args): logger.info("Initializing controlnet weights from unet") controlnet = ControlNetModel.from_unet(unet) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -908,9 +913,9 @@ def load_model_hook(models, input_dir): " doing mixed precision training, copy of the weights should still be float32." ) - if accelerator.unwrap_model(controlnet).dtype != torch.float32: + if unwrap_model(controlnet).dtype != torch.float32: raise ValueError( - f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, @@ -1158,7 +1163,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample + return_dict=False)[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -1223,7 +1228,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - controlnet = accelerator.unwrap_model(controlnet) + controlnet = unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) if args.push_to_hub: From 1be4ac992b1d5812f6494a2370d4db60b2f30afb Mon Sep 17 00:00:00 2001 From: charchit7 Date: Thu, 11 Jan 2024 16:52:45 +0530 Subject: [PATCH 2/2] fix quality --- examples/controlnet/train_controlnet_sdxl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 7d3ea3eb0d87..6a31524cd2f3 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -54,6 +54,7 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -1163,7 +1164,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - return_dict=False)[0] + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon":