@@ -38,7 +38,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -58,12 +58,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
@@ -1292,17 +1293,6 @@ def main(args):
             else:
                 param.requires_grad = False
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1358,17 +1348,34 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1383,6 +1390,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
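
Note on the float32 upcasting introduced above: judging from the inline loop this diff removes from main() and the new cast_training_params(...) calls that replace it, the intent is to upcast only the trainable (LoRA) parameters to fp32 while the frozen base weights stay in the mixed-precision dtype. A minimal sketch of that logic follows; the helper name is illustrative, not the diffusers API.

# Illustrative sketch (not the diffusers implementation): upcast only the
# trainable parameters -- the LoRA adapter weights -- to fp32, leaving the
# frozen base weights in the mixed-precision dtype (e.g. fp16).
import torch

def upcast_trainable_params(models, dtype=torch.float32):
    for model in models:
        for param in model.parameters():
            if param.requires_grad:  # only the LoRA params are trainable here
                param.data = param.data.to(dtype)

Doing this both in main() and in load_model_hook keeps the adapter weights in fp32 after resuming from a checkpoint, since the reloaded base models are again in `weight_dtype` (see the linked PR discussion in the comment above).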