Commit dce0668

Fixes torch.compile() compatible training (#6589)
resolve conflicts
1 parent dd63168

File tree

1 file changed: +12 −6 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 12 additions & 6 deletions
@@ -68,6 +68,7 @@
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1293,6 +1294,11 @@ def main(args):
             else:
                 param.requires_grad = False
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1303,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     if args.train_text_encoder:
                         text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
                         )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     if args.train_text_encoder:
                         text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
@@ -1338,11 +1344,11 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
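
For context on why the helper is needed: torch.compile() wraps a module in a dynamo OptimizedModule whose class differs from the original model's class, so the isinstance(model, type(...)) checks in the hooks above never match a compiled model. The wrapper keeps the original module in its _orig_mod attribute, which is what the new unwrap_model() reaches for after accelerator.unwrap_model() has stripped any Accelerate/DDP wrapping. Below is a minimal sketch of that behavior, assuming PyTorch 2.x; TinyModel is a hypothetical stand-in for the script's UNet and text encoders, not part of the training code.

# Minimal sketch, assuming PyTorch 2.x. TinyModel is a hypothetical
# stand-in for the script's UNet / text encoders.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
compiled = torch.compile(model)  # wraps model in an OptimizedModule

# The wrapper is a different class, so the type-based checks in the
# save/load hooks would fail against a compiled model:
print(isinstance(compiled, TinyModel))  # False
print(type(compiled).__name__)          # OptimizedModule

# The original module is kept on the wrapper as `_orig_mod`, which is
# what the patch's unwrap_model() helper reaches for:
print(compiled._orig_mod is model)                # True
print(isinstance(compiled._orig_mod, TinyModel))  # True

Routing everything through one unwrap_model() helper means the save and load hooks behave identically whether training runs eagerly, under DDP, or under torch.compile().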
