|
52 | 52 | from diffusers.optimization import get_scheduler
|
53 | 53 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
54 | 54 | from diffusers.utils.import_utils import is_xformers_available
|
| 55 | +from diffusers.utils.torch_utils import is_compiled_module |
55 | 56 |
|
56 | 57 |
|
57 | 58 | if is_wandb_available():
|
@@ -847,6 +848,11 @@ def main(args):
|
847 | 848 | logger.info("Initializing controlnet weights from unet")
|
848 | 849 | controlnet = ControlNetModel.from_unet(unet)
|
849 | 850 |
|
| 851 | + def unwrap_model(model): |
| 852 | + model = accelerator.unwrap_model(model) |
| 853 | + model = model._orig_mod if is_compiled_module(model) else model |
| 854 | + return model |
| 855 | + |
850 | 856 | # `accelerate` 0.16.0 will have better support for customized saving
|
851 | 857 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
852 | 858 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
@@ -908,9 +914,9 @@ def load_model_hook(models, input_dir):
|
908 | 914 | " doing mixed precision training, copy of the weights should still be float32."
|
909 | 915 | )
|
910 | 916 |
|
911 |
| - if accelerator.unwrap_model(controlnet).dtype != torch.float32: |
| 917 | + if unwrap_model(controlnet).dtype != torch.float32: |
912 | 918 | raise ValueError(
|
913 |
| - f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" |
| 919 | + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" |
914 | 920 | )
|
915 | 921 |
|
916 | 922 | # Enable TF32 for faster training on Ampere GPUs,
|
@@ -1158,7 +1164,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
|
1158 | 1164 | sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
1159 | 1165 | ],
|
1160 | 1166 | mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
1161 |
| - ).sample |
| 1167 | + return_dict=False, |
| 1168 | + )[0] |
1162 | 1169 |
|
1163 | 1170 | # Get the target for loss depending on the prediction type
|
1164 | 1171 | if noise_scheduler.config.prediction_type == "epsilon":
|
@@ -1223,7 +1230,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
|
1223 | 1230 | # Create the pipeline using using the trained modules and save it.
|
1224 | 1231 | accelerator.wait_for_everyone()
|
1225 | 1232 | if accelerator.is_main_process:
|
1226 |
| - controlnet = accelerator.unwrap_model(controlnet) |
| 1233 | + controlnet = unwrap_model(controlnet) |
1227 | 1234 | controlnet.save_pretrained(args.output_dir)
|
1228 | 1235 |
|
1229 | 1236 | if args.push_to_hub:
|
|
0 commit comments