 # limitations under the License.

 import re
+from typing import List

 import torch

 logger = logging.get_logger(__name__)


+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
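A quick illustration of the new module-level `swap_scale_shift` helper (a minimal sketch, not part of the diff; the toy tensor shapes are assumptions): modulation weights stored as `[shift, scale]` halves along dim 0 are reordered to the `[scale, shift]` layout used on the diffusers side.

```python
# Minimal sketch of swap_scale_shift: the first half of dim 0 ("shift") and the
# second half ("scale") trade places. Shapes chosen only for the demo.
import torch

def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.cat([torch.zeros(3, 4), torch.ones(3, 4)], dim=0)  # rows 0-2 "shift", rows 3-5 "scale"
swapped = swap_scale_shift(w)
assert torch.equal(swapped[:3], torch.ones(3, 4))   # scale half now comes first
assert torch.equal(swapped[3:], torch.zeros(3, 4))  # shift half moved to the back
```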
@@ -299,7 +306,9 @@ def _convert_text_encoder_lora_key(key, lora_name):
         key_to_replace = "lora_te2_"

     diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+
     diffusers_name = diffusers_name.replace("text.model", "text_model")
+    diffusers_name = diffusers_name.replace("position.embedding", "position_embedding")
     diffusers_name = diffusers_name.replace("self.attn", "self_attn")
     diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
     diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
@@ -313,6 +322,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
     # Be aware that this is the new diffusers convention and the rest of the code might
     # not utilize it yet.
     diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
     return diffusers_name

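For context, a small sketch (the input key is a made-up example) of why the new `position.embedding` replacement is needed: the blanket `_` to `.` substitution also splits legitimate module names that contain underscores, and they have to be stitched back together afterwards.

```python
# Illustrative only: the underscore-to-dot pass breaks "position_embedding",
# and the new replacement line restores it.
key = "lora_te1_text_model_embeddings_position_embedding"

diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
# -> "text.model.embeddings.position.embedding"
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("position.embedding", "position_embedding")
print(diffusers_name)  # "text_model.embeddings.position_embedding"
```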
@@ -341,7 +351,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):

     # scale weight by alpha and dim
     rank = down_weight.shape[0]
-    alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+    default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+    alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
     scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

     # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
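A minimal sketch of what the new default does (the tensor shape and key name are assumptions for illustration): sd-scripts LoRAs are applied with an `alpha / rank` factor, so the converter folds that factor into the weights; when a checkpoint ships without `.alpha` entries, defaulting alpha to the rank leaves the weights untouched.

```python
# Illustrative only: defaulting alpha to the rank yields a neutral scale of 1.0
# when the ".alpha" tensor is missing from the checkpoint.
import torch

sds_sd = {}                               # pretend this checkpoint has no ".alpha" entries
down_weight = torch.randn(8, 64)          # a rank-8 LoRA "down" matrix (shape assumed)
rank = down_weight.shape[0]

default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device)
alpha = sds_sd.pop("some_key.alpha", default_alpha).item()
scale = alpha / rank                      # 1.0 -> nothing gets rescaled
print(scale)
```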
@@ -362,7 +373,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     sd_lora_rank = down_weight.shape[0]

     # scale weight by alpha and dim
-    alpha = sds_sd.pop(sds_key + ".alpha")
+    default_alpha = torch.tensor(
+        sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+    )
+    alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
     scale = alpha / sd_lora_rank

     # calculate scale_down and scale_up
@@ -516,10 +530,62 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
             f"transformer.single_transformer_blocks.{i}.norm.linear",
         )

+    # TODO: alphas.
+    if any("final_layer" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            # Notice the swap.
+            ait_sd[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+                sds_sd.pop(f"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight")
+            )
+            ait_sd[f"proj_out.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_final_layer_linear.{orig_lora_key}.weight"
+            )
+
+    if any("guidance_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight"
+            )
+
+    if any("img_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"x_embedder.{lora_key}.weight"] = sds_sd.pop(f"lora_unet_img_in.{orig_lora_key}.weight")
+
+    if any("txt_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"context_embedder.{lora_key}.weight"] = sds_sd.pop(f"lora_unet_txt_in.{orig_lora_key}.weight")
+
+    if any("time_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_time_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_time_in_out_layer.{orig_lora_key}.weight"
+            )
+
+    if any("vector_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_vector_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_vector_in_out_layer.{orig_lora_key}.weight"
+            )
+
     remaining_keys = list(sds_sd.keys())
     te_state_dict = {}
     if remaining_keys:
-        if not all(k.startswith("lora_te") for k in remaining_keys):
+        if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
             raise ValueError(f"Incompatible keys detected: \n\n{', '.join(remaining_keys)}")
         for key in remaining_keys:
             if not key.endswith("lora_down.weight"):
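In other words, a summary sketch derived from the hunk above (shown only for the `lora_A`/`lora_down` direction): the new blocks perform straightforward key renames from the sd-scripts layout to the diffusers layout, for example:

```python
# Illustrative only: a few of the sd-scripts -> diffusers renames added above.
renames = {
    "lora_unet_final_layer_linear.lora_down.weight": "proj_out.lora_A.weight",
    "lora_unet_img_in.lora_down.weight": "x_embedder.lora_A.weight",
    "lora_unet_txt_in.lora_down.weight": "context_embedder.lora_A.weight",
    "lora_unet_time_in_in_layer.lora_down.weight": "time_text_embed.timestep_embedder.linear_1.lora_A.weight",
    "lora_unet_vector_in_out_layer.lora_down.weight": "time_text_embed.text_embedder.linear_2.lora_A.weight",
}
```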
@@ -680,10 +746,59 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if has_peft_state_dict:
         state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
         return state_dict
+
     # Another weird one.
     has_mixture = any(
         k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
     )
+
+    # ComfyUI.
+    state_dict = {k.replace("diffusion_model.", "lora_unet."): v for k, v in state_dict.items()}
+    state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te."): v for k, v in state_dict.items()}
+    has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+    if has_t5xxl:
+        logger.info(
+            "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+            "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+    any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet.", "lora_te.")) for k in state_dict)
+    if any_diffb_keys:
+        logger.info(
+            "`diff_b` keys found in the state dict which are currently unsupported. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "diff_b" not in k}
+
+    any_norm_diff_keys = any("norm" in k and "diff" in k for k in state_dict)
+    if any_norm_diff_keys:
+        logger.info(
+            "Normalization diff keys found in the state dict which are currently unsupported. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "norm" not in k and "diff" not in k}
+
+    limit_substrings = ["lora_down", "lora_up"]
+    if any("alpha" in k for k in state_dict):
+        limit_substrings.append("alpha")
+
+    state_dict = {
+        _custom_replace(k, limit_substrings): v
+        for k, v in state_dict.items()
+        if k.startswith(("lora_unet.", "lora_te."))
+    }
+
+    if any("text_projection" in k for k in state_dict):
+        logger.info(
+            "`text_projection` keys found in the state_dict which are unexpected. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)

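A minimal sketch of what the new ComfyUI branch does to a few keys before the rest of the converter runs (the key names below are made up for the example): the ComfyUI prefixes are remapped to the sd-scripts-style `lora_unet.`/`lora_te.` prefixes, and unsupported T5-XXL keys are dropped.

```python
# Illustrative only: prefix remapping plus filtering of unsupported key families,
# mirroring the new ComfyUI handling above.
state_dict = {
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_down.weight": 0,
    "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.lora_down.weight": 0,
    "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.lora_down.weight": 0,
}

state_dict = {k.replace("diffusion_model.", "lora_unet."): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te."): v for k, v in state_dict.items()}
# T5-XXL keys are filtered out because they are not supported yet.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}

print(sorted(state_dict))
# ['lora_te.text_model.encoder.layers.0.self_attn.q_proj.lora_down.weight',
#  'lora_unet.double_blocks.0.img_attn.qkv.lora_down.weight']
```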
@@ -798,6 +913,23 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     return new_state_dict


+def _custom_replace(key: str, substrings: List[str]) -> str:
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
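To show how `_custom_replace` behaves end to end (a self-contained sketch that simply restates the helper from the hunk above, with an example key assumed for illustration): it rewrites the dotted prefix of a key into underscores but leaves everything from the first `lora_down`/`lora_up`/`alpha` substring onwards untouched.

```python
# Illustrative only: dots before the matched substring become underscores,
# while the ".lora_down.weight" suffix is preserved as-is.
import re
from typing import List

def _custom_replace(key: str, substrings: List[str]) -> str:
    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
    match = re.search(pattern, key)
    if match:
        start_sub = match.start()
        boundary = start_sub - 1 if start_sub > 0 and key[start_sub - 1] == "." else start_sub
        return key[:boundary].replace(".", "_") + key[boundary:]
    return key.replace(".", "_")

key = "lora_unet.double_blocks.0.img_attn.qkv.lora_down.weight"
print(_custom_replace(key, ["lora_down", "lora_up"]))
# lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
```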
@@ -806,11 +938,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0

-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
         converted_state_dict[