 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
@@ -1368,7 +1368,6 @@ def load_model_hook(models, input_dir):
                 )

         if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
             _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

             _set_state_dict_into_text_encoder(
@@ -1382,11 +1381,7 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            cast_training_params(models)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1406,11 +1401,7 @@ def load_model_hook(models, input_dir):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        cast_training_params(models, dtype=torch.float32)

     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

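Note: `cast_training_params` is imported from `diffusers.training_utils` in the first hunk. As a minimal sketch (an assumption that mirrors the hand-written loop removed in this commit, not necessarily the library's exact implementation), the helper does roughly the following:

import torch

def cast_training_params(models, dtype=torch.float32):
    # Sketch only: upcast trainable (requires_grad) parameters, e.g. the
    # LoRA weights, to `dtype` while leaving frozen base weights untouched.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

Collapsing the two duplicated loops into this single helper keeps both call sites (the load-state hook and the main training path) behaving the same way.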