@@ -2364,7 +2364,7 @@ def save_lora_weights(
2364
2364
2365
2365
class CogVideoXLoraLoaderMixin (LoraBaseMixin ):
2366
2366
r"""
2367
- Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX `].
2367
+ Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline `].
2368
2368
"""
2369
2369
2370
2370
_lora_loadable_modules = ["transformer" ]
@@ -2669,6 +2669,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
2669
2669
super ().unfuse_lora (components = components )
2670
2670
2671
2671
2672
+ class Mochi1LoraLoaderMixin (LoraBaseMixin ):
2673
+ r"""
2674
+ Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
2675
+ """
2676
+
2677
+ _lora_loadable_modules = ["transformer" ]
2678
+ transformer_name = TRANSFORMER_NAME
2679
+
2680
+ @classmethod
2681
+ @validate_hf_hub_args
2682
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
2683
+ def lora_state_dict (
2684
+ cls ,
2685
+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
2686
+ ** kwargs ,
2687
+ ):
2688
+ r"""
2689
+ Return state dict for lora weights and the network alphas.
2690
+
2691
+ <Tip warning={true}>
2692
+
2693
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
2694
+
2695
+ This function is experimental and might change in the future.
2696
+
2697
+ </Tip>
2698
+
2699
+ Parameters:
2700
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
2701
+ Can be either:
2702
+
2703
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
2704
+ the Hub.
2705
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
2706
+ with [`ModelMixin.save_pretrained`].
2707
+ - A [torch state
2708
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
2709
+
2710
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
2711
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
2712
+ is not used.
2713
+ force_download (`bool`, *optional*, defaults to `False`):
2714
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
2715
+ cached versions if they exist.
2716
+
2717
+ proxies (`Dict[str, str]`, *optional*):
2718
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
2719
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
2720
+ local_files_only (`bool`, *optional*, defaults to `False`):
2721
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
2722
+ won't be downloaded from the Hub.
2723
+ token (`str` or *bool*, *optional*):
2724
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
2725
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
2726
+ revision (`str`, *optional*, defaults to `"main"`):
2727
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
2728
+ allowed by Git.
2729
+ subfolder (`str`, *optional*, defaults to `""`):
2730
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
2731
+
2732
+ """
2733
+ # Load the main state dict first which has the LoRA layers for either of
2734
+ # transformer and text encoder or both.
2735
+ cache_dir = kwargs .pop ("cache_dir" , None )
2736
+ force_download = kwargs .pop ("force_download" , False )
2737
+ proxies = kwargs .pop ("proxies" , None )
2738
+ local_files_only = kwargs .pop ("local_files_only" , None )
2739
+ token = kwargs .pop ("token" , None )
2740
+ revision = kwargs .pop ("revision" , None )
2741
+ subfolder = kwargs .pop ("subfolder" , None )
2742
+ weight_name = kwargs .pop ("weight_name" , None )
2743
+ use_safetensors = kwargs .pop ("use_safetensors" , None )
2744
+
2745
+ allow_pickle = False
2746
+ if use_safetensors is None :
2747
+ use_safetensors = True
2748
+ allow_pickle = True
2749
+
2750
+ user_agent = {
2751
+ "file_type" : "attn_procs_weights" ,
2752
+ "framework" : "pytorch" ,
2753
+ }
2754
+
2755
+ state_dict = _fetch_state_dict (
2756
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
2757
+ weight_name = weight_name ,
2758
+ use_safetensors = use_safetensors ,
2759
+ local_files_only = local_files_only ,
2760
+ cache_dir = cache_dir ,
2761
+ force_download = force_download ,
2762
+ proxies = proxies ,
2763
+ token = token ,
2764
+ revision = revision ,
2765
+ subfolder = subfolder ,
2766
+ user_agent = user_agent ,
2767
+ allow_pickle = allow_pickle ,
2768
+ )
2769
+
2770
+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
2771
+ if is_dora_scale_present :
2772
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
2773
+ logger .warning (warn_msg )
2774
+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
2775
+
2776
+ return state_dict
2777
+
2778
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
2779
+ def load_lora_weights (
2780
+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
2781
+ ):
2782
+ """
2783
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
2784
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
2785
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
2786
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
2787
+ dict is loaded into `self.transformer`.
2788
+
2789
+ Parameters:
2790
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
2791
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2792
+ adapter_name (`str`, *optional*):
2793
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2794
+ `default_{i}` where i is the total number of adapters being loaded.
2795
+ low_cpu_mem_usage (`bool`, *optional*):
2796
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2797
+ weights.
2798
+ kwargs (`dict`, *optional*):
2799
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2800
+ """
2801
+ if not USE_PEFT_BACKEND :
2802
+ raise ValueError ("PEFT backend is required for this method." )
2803
+
2804
+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
2805
+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
2806
+ raise ValueError (
2807
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2808
+ )
2809
+
2810
+ # if a dict is passed, copy it instead of modifying it inplace
2811
+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
2812
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
2813
+
2814
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
2815
+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
2816
+
2817
+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
2818
+ if not is_correct_format :
2819
+ raise ValueError ("Invalid LoRA checkpoint." )
2820
+
2821
+ self .load_lora_into_transformer (
2822
+ state_dict ,
2823
+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
2824
+ adapter_name = adapter_name ,
2825
+ _pipeline = self ,
2826
+ low_cpu_mem_usage = low_cpu_mem_usage ,
2827
+ )
2828
+
2829
+ @classmethod
2830
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
2831
+ def load_lora_into_transformer (
2832
+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
2833
+ ):
2834
+ """
2835
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
2836
+
2837
+ Parameters:
2838
+ state_dict (`dict`):
2839
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
2840
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
2841
+ encoder lora layers.
2842
+ transformer (`CogVideoXTransformer3DModel`):
2843
+ The Transformer model to load the LoRA layers into.
2844
+ adapter_name (`str`, *optional*):
2845
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2846
+ `default_{i}` where i is the total number of adapters being loaded.
2847
+ low_cpu_mem_usage (`bool`, *optional*):
2848
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2849
+ weights.
2850
+ """
2851
+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
2852
+ raise ValueError (
2853
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2854
+ )
2855
+
2856
+ # Load the layers corresponding to transformer.
2857
+ logger .info (f"Loading { cls .transformer_name } ." )
2858
+ transformer .load_lora_adapter (
2859
+ state_dict ,
2860
+ network_alphas = None ,
2861
+ adapter_name = adapter_name ,
2862
+ _pipeline = _pipeline ,
2863
+ low_cpu_mem_usage = low_cpu_mem_usage ,
2864
+ )
2865
+
2866
+ @classmethod
2867
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
2868
+ def save_lora_weights (
2869
+ cls ,
2870
+ save_directory : Union [str , os .PathLike ],
2871
+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
2872
+ is_main_process : bool = True ,
2873
+ weight_name : str = None ,
2874
+ save_function : Callable = None ,
2875
+ safe_serialization : bool = True ,
2876
+ ):
2877
+ r"""
2878
+ Save the LoRA parameters corresponding to the UNet and text encoder.
2879
+
2880
+ Arguments:
2881
+ save_directory (`str` or `os.PathLike`):
2882
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
2883
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2884
+ State dict of the LoRA layers corresponding to the `transformer`.
2885
+ is_main_process (`bool`, *optional*, defaults to `True`):
2886
+ Whether the process calling this is the main process or not. Useful during distributed training and you
2887
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
2888
+ process to avoid race conditions.
2889
+ save_function (`Callable`):
2890
+ The function to use to save the state dictionary. Useful during distributed training when you need to
2891
+ replace `torch.save` with another method. Can be configured with the environment variable
2892
+ `DIFFUSERS_SAVE_MODE`.
2893
+ safe_serialization (`bool`, *optional*, defaults to `True`):
2894
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2895
+ """
2896
+ state_dict = {}
2897
+
2898
+ if not transformer_lora_layers :
2899
+ raise ValueError ("You must pass `transformer_lora_layers`." )
2900
+
2901
+ if transformer_lora_layers :
2902
+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
2903
+
2904
+ # Save the model
2905
+ cls .write_lora_layers (
2906
+ state_dict = state_dict ,
2907
+ save_directory = save_directory ,
2908
+ is_main_process = is_main_process ,
2909
+ weight_name = weight_name ,
2910
+ save_function = save_function ,
2911
+ safe_serialization = safe_serialization ,
2912
+ )
2913
+
2914
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
2915
+ def fuse_lora (
2916
+ self ,
2917
+ components : List [str ] = ["transformer" , "text_encoder" ],
2918
+ lora_scale : float = 1.0 ,
2919
+ safe_fusing : bool = False ,
2920
+ adapter_names : Optional [List [str ]] = None ,
2921
+ ** kwargs ,
2922
+ ):
2923
+ r"""
2924
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
2925
+
2926
+ <Tip warning={true}>
2927
+
2928
+ This is an experimental API.
2929
+
2930
+ </Tip>
2931
+
2932
+ Args:
2933
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
2934
+ lora_scale (`float`, defaults to 1.0):
2935
+ Controls how much to influence the outputs with the LoRA parameters.
2936
+ safe_fusing (`bool`, defaults to `False`):
2937
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
2938
+ adapter_names (`List[str]`, *optional*):
2939
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
2940
+
2941
+ Example:
2942
+
2943
+ ```py
2944
+ from diffusers import DiffusionPipeline
2945
+ import torch
2946
+
2947
+ pipeline = DiffusionPipeline.from_pretrained(
2948
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
2949
+ ).to("cuda")
2950
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
2951
+ pipeline.fuse_lora(lora_scale=0.7)
2952
+ ```
2953
+ """
2954
+ super ().fuse_lora (
2955
+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
2956
+ )
2957
+
2958
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
2959
+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
2960
+ r"""
2961
+ Reverses the effect of
2962
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
2963
+
2964
+ <Tip warning={true}>
2965
+
2966
+ This is an experimental API.
2967
+
2968
+ </Tip>
2969
+
2970
+ Args:
2971
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
2972
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
2973
+ unfuse_text_encoder (`bool`, defaults to `True`):
2974
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
2975
+ LoRA parameters then it won't have any effect.
2976
+ """
2977
+ super ().unfuse_lora (components = components )
2978
+
2979
+
2672
2980
class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
2673
2981
def __init__ (self , * args , ** kwargs ):
2674
2982
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments