@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
339
339
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
340
340
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
341
341
# their prefixes.
342
- keys = list (state_dict .keys ())
343
342
prefix = text_encoder_name if prefix is None else prefix
344
343
345
- # Safe prefix to check with.
346
- if any (text_encoder_name in key for key in keys ):
347
- # Load the layers corresponding to text encoder and make necessary adjustments.
348
- text_encoder_keys = [k for k in keys if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
349
- text_encoder_lora_state_dict = {
350
- k .replace (f"{ prefix } ." , "" ): v for k , v in state_dict .items () if k in text_encoder_keys
351
- }
344
+ # Load the layers corresponding to text encoder and make necessary adjustments.
345
+ if prefix is not None :
346
+ state_dict = {k [len (f"{ prefix } ." ) :]: v for k , v in state_dict .items () if k .startswith (f"{ prefix } ." )}
347
+
348
+ if len (state_dict ) > 0 :
349
+ logger .info (f"Loading { prefix } ." )
350
+ rank = {}
351
+ state_dict = convert_state_dict_to_diffusers (state_dict )
352
+
353
+ # convert state dict
354
+ state_dict = convert_state_dict_to_peft (state_dict )
355
+
356
+ for name , _ in text_encoder_attn_modules (text_encoder ):
357
+ for module in ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
358
+ rank_key = f"{ name } .{ module } .lora_B.weight"
359
+ if rank_key not in state_dict :
360
+ continue
361
+ rank [rank_key ] = state_dict [rank_key ].shape [1 ]
362
+
363
+ for name , _ in text_encoder_mlp_modules (text_encoder ):
364
+ for module in ("fc1" , "fc2" ):
365
+ rank_key = f"{ name } .{ module } .lora_B.weight"
366
+ if rank_key not in state_dict :
367
+ continue
368
+ rank [rank_key ] = state_dict [rank_key ].shape [1 ]
369
+
370
+ if network_alphas is not None :
371
+ alpha_keys = [k for k in network_alphas .keys () if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
372
+ network_alphas = {k .replace (f"{ prefix } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys }
373
+
374
+ lora_config_kwargs = get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
375
+
376
+ if "use_dora" in lora_config_kwargs :
377
+ if lora_config_kwargs ["use_dora" ]:
378
+ if is_peft_version ("<" , "0.9.0" ):
379
+ raise ValueError (
380
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
381
+ )
382
+ else :
383
+ if is_peft_version ("<" , "0.9.0" ):
384
+ lora_config_kwargs .pop ("use_dora" )
385
+
386
+ if "lora_bias" in lora_config_kwargs :
387
+ if lora_config_kwargs ["lora_bias" ]:
388
+ if is_peft_version ("<=" , "0.13.2" ):
389
+ raise ValueError (
390
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
391
+ )
392
+ else :
393
+ if is_peft_version ("<=" , "0.13.2" ):
394
+ lora_config_kwargs .pop ("lora_bias" )
352
395
353
- if len (text_encoder_lora_state_dict ) > 0 :
354
- logger .info (f"Loading { prefix } ." )
355
- rank = {}
356
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers (text_encoder_lora_state_dict )
357
-
358
- # convert state dict
359
- text_encoder_lora_state_dict = convert_state_dict_to_peft (text_encoder_lora_state_dict )
360
-
361
- for name , _ in text_encoder_attn_modules (text_encoder ):
362
- for module in ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
363
- rank_key = f"{ name } .{ module } .lora_B.weight"
364
- if rank_key not in text_encoder_lora_state_dict :
365
- continue
366
- rank [rank_key ] = text_encoder_lora_state_dict [rank_key ].shape [1 ]
367
-
368
- for name , _ in text_encoder_mlp_modules (text_encoder ):
369
- for module in ("fc1" , "fc2" ):
370
- rank_key = f"{ name } .{ module } .lora_B.weight"
371
- if rank_key not in text_encoder_lora_state_dict :
372
- continue
373
- rank [rank_key ] = text_encoder_lora_state_dict [rank_key ].shape [1 ]
374
-
375
- if network_alphas is not None :
376
- alpha_keys = [k for k in network_alphas .keys () if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
377
- network_alphas = {k .replace (f"{ prefix } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys }
378
-
379
- lora_config_kwargs = get_peft_kwargs (rank , network_alphas , text_encoder_lora_state_dict , is_unet = False )
380
-
381
- if "use_dora" in lora_config_kwargs :
382
- if lora_config_kwargs ["use_dora" ]:
383
- if is_peft_version ("<" , "0.9.0" ):
384
- raise ValueError (
385
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
386
- )
387
- else :
388
- if is_peft_version ("<" , "0.9.0" ):
389
- lora_config_kwargs .pop ("use_dora" )
390
-
391
- if "lora_bias" in lora_config_kwargs :
392
- if lora_config_kwargs ["lora_bias" ]:
393
- if is_peft_version ("<=" , "0.13.2" ):
394
- raise ValueError (
395
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
396
- )
397
- else :
398
- if is_peft_version ("<=" , "0.13.2" ):
399
- lora_config_kwargs .pop ("lora_bias" )
396
+ lora_config = LoraConfig (** lora_config_kwargs )
400
397
401
- lora_config = LoraConfig (** lora_config_kwargs )
398
+ # adapter_name
399
+ if adapter_name is None :
400
+ adapter_name = get_adapter_name (text_encoder )
402
401
403
- # adapter_name
404
- if adapter_name is None :
405
- adapter_name = get_adapter_name (text_encoder )
402
+ is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
406
403
407
- is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
404
+ # inject LoRA layers and load the state dict
405
+ # in transformers we automatically check whether the adapter name is already in use or not
406
+ text_encoder .load_adapter (
407
+ adapter_name = adapter_name ,
408
+ adapter_state_dict = state_dict ,
409
+ peft_config = lora_config ,
410
+ ** peft_kwargs ,
411
+ )
408
412
409
- # inject LoRA layers and load the state dict
410
- # in transformers we automatically check whether the adapter name is already in use or not
411
- text_encoder .load_adapter (
412
- adapter_name = adapter_name ,
413
- adapter_state_dict = text_encoder_lora_state_dict ,
414
- peft_config = lora_config ,
415
- ** peft_kwargs ,
416
- )
413
+ # scale LoRA layers with `lora_scale`
414
+ scale_lora_layers (text_encoder , weight = lora_scale )
417
415
418
- # scale LoRA layers with `lora_scale`
419
- scale_lora_layers (text_encoder , weight = lora_scale )
416
+ text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
420
417
421
- text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
418
+ # Offload back.
419
+ if is_model_cpu_offload :
420
+ _pipeline .enable_model_cpu_offload ()
421
+ elif is_sequential_cpu_offload :
422
+ _pipeline .enable_sequential_cpu_offload ()
423
+ # Unsafe code />
422
424
423
- # Offload back.
424
- if is_model_cpu_offload :
425
- _pipeline .enable_model_cpu_offload ()
426
- elif is_sequential_cpu_offload :
427
- _pipeline .enable_sequential_cpu_offload ()
428
- # Unsafe code />
425
+ if prefix is not None and not state_dict :
426
+ logger .info (
427
+ f"No LoRA keys associated to { text_encoder .__class__ .__name__ } found with the { prefix = } . This is safe to ignore if LoRA state dict didn't originally have any { text_encoder .__class__ .__name__ } related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
428
+ )
429
429
430
430
431
431
def _func_optionally_disable_offloading (_pipeline ):
0 commit comments