Commit ca1f7df
fix nit, use cast_training_params
1 parent 938a653

1 file changed: +3 -12 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 3 additions & 12 deletions
@@ -58,7 +58,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
@@ -1368,7 +1368,6 @@ def load_model_hook(models, input_dir):
                 )
 
         if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
             _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
 
             _set_state_dict_into_text_encoder(
@@ -1382,11 +1381,7 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1406,11 +1401,7 @@ def load_model_hook(models, input_dir):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        cast_training_params(models, dtype=torch.float32)
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
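
The nit being fixed is the hand-rolled upcasting loop: the script now imports cast_training_params from diffusers.training_utils and calls it on the list of models instead. Judging from the removed lines, the helper moves only the trainable (requires_grad) parameters, which in this script are the LoRA adapters, to the target dtype (fp32 by default), while the frozen base weights stay in the mixed-precision dtype. A minimal sketch of that behavior, reconstructed from the deleted loop rather than taken from the helper's actual source:

import torch

def cast_training_params_sketch(models, dtype=torch.float32):
    """Upcast only trainable parameters (e.g. LoRA weights) to `dtype`.

    Sketch of what diffusers.training_utils.cast_training_params is expected
    to do, based on the loop removed in this commit; not the library's actual
    implementation.
    """
    for model in models:
        for param in model.parameters():
            # Frozen base weights keep their fp16/bf16 dtype; only the LoRA
            # parameters that will receive gradients are upcast so the
            # optimizer updates run in full precision.
            if param.requires_grad:
                param.data = param.to(dtype)

Centralizing this in one helper keeps the two call sites (the load_model_hook and the pre-training setup) from drifting apart.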
