 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
[...]
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
     convert_state_dict_to_kohya,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1292,16 +1294,10 @@ def main(args):
             else:
                 param.requires_grad = False

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
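The `unwrap_model` helper added above is needed because both `accelerator.prepare()` and `torch.compile` wrap the underlying module, so `isinstance` checks against the original class fail on the wrapper. A minimal standalone sketch of the `torch.compile` half of that behaviour (the `nn.Linear` stand-in and variable names are illustrative, not from the script):

import torch
import torch.nn as nn

# torch.compile returns an OptimizedModule that keeps the original module on
# `_orig_mod`; the wrapper is not an instance of the original class, which is
# why the hooks below unwrap before comparing types.
base = nn.Linear(4, 4)                    # stand-in for the real UNet
compiled = torch.compile(base)
print(isinstance(compiled, nn.Linear))    # False
print(compiled._orig_mod is base)         # True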
@@ -1313,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     if args.train_text_encoder:
                         text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
                         )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     if args.train_text_encoder:
                         text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
@@ -1348,27 +1344,49 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
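The rewritten `load_model_hook` above no longer routes through `LoraLoaderMixin.load_lora_into_unet` / `load_lora_into_text_encoder`; it splits the combined LoRA state dict by prefix and sets each slice directly into the already-prepared models. A toy sketch of the prefix-splitting step (keys and values invented for illustration; the real dict comes from `LoraLoaderMixin.lora_state_dict(input_dir)`):

# toy combined state dict with one key per component
lora_state_dict = {
    "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": "u",
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "t1",
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "t2",
}

# keep only the UNet entries and strip the "unet." prefix, mirroring the dict
# comprehension in the hunk above; the result is what then gets converted to
# PEFT naming and loaded via set_peft_model_state_dict
unet_state_dict = {k[len("unet."):]: v for k, v in lora_state_dict.items() if k.startswith("unet.")}
print(unet_state_dict)  # {'down_blocks.0.attentions.0.to_q.lora_A.weight': 'u'}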
@@ -1383,6 +1401,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

     if args.train_text_encoder:
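The float32 upcast removed from its original spot (the 1292 hunk) reappears here, just before `unet_lora_parameters` is collected, so the optimizer sees parameters that are already in fp32 while the frozen base weights stay in `weight_dtype`. A hypothetical minimal sketch of that ordering (the helper name and learning rate are illustrative, not part of the script):

import torch

def build_optimizer(models, lr=1e-4):
    # upcast only the trainable (LoRA) parameters, then hand them to the optimizer
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(torch.float32)
    trainable = [p for m in models for p in m.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)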