|
17 | 17 |
|
18 | 18 | import torch
|
19 | 19 |
|
20 |
| -from ..utils import is_peft_version, logging |
| 20 | +from ..utils import is_peft_version, logging, state_dict_all_zero |
21 | 21 |
|
22 | 22 |
|
23 | 23 | logger = logging.get_logger(__name__)
|
@@ -755,29 +755,67 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
|
755 | 755 | state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
|
756 | 756 | state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
|
757 | 757 | has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
|
| 758 | + |
| 759 | + if any("position_embedding" in k for k in state_dict): |
| 760 | + zero_status_pe = state_dict_all_zero(state_dict, "position_embedding") |
| 761 | + if zero_status_pe: |
| 762 | + logger.info( |
| 763 | + "The `position_embedding` LoRA params are all zeros which make them ineffective. " |
| 764 | + "So, we will purge them out of the curret state dict to make loading possible." |
| 765 | + ) |
| 766 | + current_pe_lora_keys = [k for k in state_dict if "position_embedding" in k] |
| 767 | + for k in current_pe_lora_keys: |
| 768 | + state_dict.pop(k) |
| 769 | + else: |
| 770 | + raise NotImplementedError( |
| 771 | + "The state_dict has position_embedding LoRA params and we currently do not support them. " |
| 772 | + "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new." |
| 773 | + ) |
| 774 | + |
758 | 775 | if has_t5xxl:
|
759 |
| - logger.info( |
760 |
| - "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out." |
761 |
| - "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." |
762 |
| - ) |
| 776 | + zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl") |
| 777 | + if zero_status_t5: |
| 778 | + logger.info( |
| 779 | + "The `t5xxl` LoRA params are all zeros which make them ineffective. " |
| 780 | + "So, we will purge them out of the curret state dict to make loading possible." |
| 781 | + ) |
| 782 | + else: |
| 783 | + logger.info( |
| 784 | + "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out." |
| 785 | + "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." |
| 786 | + ) |
763 | 787 | state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
764 | 788 |
|
765 | 789 | any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
766 | 790 | if any_diffb_keys:
|
767 |
| - logger.info( |
768 |
| - "`diff_b` keys found in the state dict which are currently unsupported. " |
769 |
| - "So, we will filter out those keys. Open an issue if this is a problem - " |
770 |
| - "https://github.com/huggingface/diffusers/issues/new." |
771 |
| - ) |
| 791 | + zero_status_diff_b = state_dict_all_zero(state_dict, "diff_b") |
| 792 | + if zero_status_diff_b: |
| 793 | + logger.info( |
| 794 | + "The `diff_b` LoRA params are all zeros which make them ineffective. " |
| 795 | + "So, we will purge them out of the curret state dict to make loading possible." |
| 796 | + ) |
| 797 | + else: |
| 798 | + logger.info( |
| 799 | + "`diff_b` keys found in the state dict which are currently unsupported. " |
| 800 | + "So, we will filter out those keys. Open an issue if this is a problem - " |
| 801 | + "https://github.com/huggingface/diffusers/issues/new." |
| 802 | + ) |
772 | 803 | state_dict = {k: v for k, v in state_dict.items() if "diff_b" not in k}
|
773 | 804 |
|
774 | 805 | any_norm_diff_keys = any("norm" in k and "diff" in k for k in state_dict)
|
775 | 806 | if any_norm_diff_keys:
|
776 |
| - logger.info( |
777 |
| - "Normalization diff keys found in the state dict which are currently unsupported. " |
778 |
| - "So, we will filter out those keys. Open an issue if this is a problem - " |
779 |
| - "https://github.com/huggingface/diffusers/issues/new." |
780 |
| - ) |
| 807 | + zero_status_diff = state_dict_all_zero(state_dict, "diff") |
| 808 | + if zero_status_diff: |
| 809 | + logger.info( |
| 810 | + "The `diff` LoRA params are all zeros which make them ineffective. " |
| 811 | + "So, we will purge them out of the curret state dict to make loading possible." |
| 812 | + ) |
| 813 | + else: |
| 814 | + logger.info( |
| 815 | + "Normalization diff keys found in the state dict which are currently unsupported. " |
| 816 | + "So, we will filter out those keys. Open an issue if this is a problem - " |
| 817 | + "https://github.com/huggingface/diffusers/issues/new." |
| 818 | + ) |
781 | 819 | state_dict = {k: v for k, v in state_dict.items() if "norm" not in k and "diff" not in k}
|
782 | 820 |
|
783 | 821 | limit_substrings = ["lora_down", "lora_up"]
|
|
0 commit comments