Skip to content

Commit 805aa93

Browse files
authored
[LoRA] enable LoRA for Mochi-1 (#9943)
* feat: add lora support to Mochi-1.
1 parent f6f7afa commit 805aa93

File tree

6 files changed

+522
-5
lines changed

6 files changed

+522
-5
lines changed

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def text_encoder_attn_modules(text_encoder):
6868
"LoraLoaderMixin",
6969
"FluxLoraLoaderMixin",
7070
"CogVideoXLoraLoaderMixin",
71+
"Mochi1LoraLoaderMixin",
7172
]
7273
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
7374
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -88,6 +89,7 @@ def text_encoder_attn_modules(text_encoder):
8889
CogVideoXLoraLoaderMixin,
8990
FluxLoraLoaderMixin,
9091
LoraLoaderMixin,
92+
Mochi1LoraLoaderMixin,
9193
SD3LoraLoaderMixin,
9294
StableDiffusionLoraLoaderMixin,
9395
StableDiffusionXLLoraLoaderMixin,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 309 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2364,7 +2364,7 @@ def save_lora_weights(
23642364

23652365
class CogVideoXLoraLoaderMixin(LoraBaseMixin):
23662366
r"""
2367-
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
2367+
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
23682368
"""
23692369

23702370
_lora_loadable_modules = ["transformer"]
@@ -2669,6 +2669,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
26692669
super().unfuse_lora(components=components)
26702670

26712671

2672+
class Mochi1LoraLoaderMixin(LoraBaseMixin):
2673+
r"""
2674+
Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
2675+
"""
2676+
2677+
_lora_loadable_modules = ["transformer"]
2678+
transformer_name = TRANSFORMER_NAME
2679+
2680+
@classmethod
2681+
@validate_hf_hub_args
2682+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
2683+
def lora_state_dict(
2684+
cls,
2685+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
2686+
**kwargs,
2687+
):
2688+
r"""
2689+
Return state dict for lora weights and the network alphas.
2690+
2691+
<Tip warning={true}>
2692+
2693+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
2694+
2695+
This function is experimental and might change in the future.
2696+
2697+
</Tip>
2698+
2699+
Parameters:
2700+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
2701+
Can be either:
2702+
2703+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
2704+
the Hub.
2705+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
2706+
with [`ModelMixin.save_pretrained`].
2707+
- A [torch state
2708+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
2709+
2710+
cache_dir (`Union[str, os.PathLike]`, *optional*):
2711+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
2712+
is not used.
2713+
force_download (`bool`, *optional*, defaults to `False`):
2714+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
2715+
cached versions if they exist.
2716+
2717+
proxies (`Dict[str, str]`, *optional*):
2718+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
2719+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
2720+
local_files_only (`bool`, *optional*, defaults to `False`):
2721+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
2722+
won't be downloaded from the Hub.
2723+
token (`str` or *bool*, *optional*):
2724+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
2725+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
2726+
revision (`str`, *optional*, defaults to `"main"`):
2727+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
2728+
allowed by Git.
2729+
subfolder (`str`, *optional*, defaults to `""`):
2730+
The subfolder location of a model file within a larger model repository on the Hub or locally.
2731+
2732+
"""
2733+
# Load the main state dict first which has the LoRA layers for either of
2734+
# transformer and text encoder or both.
2735+
cache_dir = kwargs.pop("cache_dir", None)
2736+
force_download = kwargs.pop("force_download", False)
2737+
proxies = kwargs.pop("proxies", None)
2738+
local_files_only = kwargs.pop("local_files_only", None)
2739+
token = kwargs.pop("token", None)
2740+
revision = kwargs.pop("revision", None)
2741+
subfolder = kwargs.pop("subfolder", None)
2742+
weight_name = kwargs.pop("weight_name", None)
2743+
use_safetensors = kwargs.pop("use_safetensors", None)
2744+
2745+
allow_pickle = False
2746+
if use_safetensors is None:
2747+
use_safetensors = True
2748+
allow_pickle = True
2749+
2750+
user_agent = {
2751+
"file_type": "attn_procs_weights",
2752+
"framework": "pytorch",
2753+
}
2754+
2755+
state_dict = _fetch_state_dict(
2756+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
2757+
weight_name=weight_name,
2758+
use_safetensors=use_safetensors,
2759+
local_files_only=local_files_only,
2760+
cache_dir=cache_dir,
2761+
force_download=force_download,
2762+
proxies=proxies,
2763+
token=token,
2764+
revision=revision,
2765+
subfolder=subfolder,
2766+
user_agent=user_agent,
2767+
allow_pickle=allow_pickle,
2768+
)
2769+
2770+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
2771+
if is_dora_scale_present:
2772+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
2773+
logger.warning(warn_msg)
2774+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
2775+
2776+
return state_dict
2777+
2778+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
2779+
def load_lora_weights(
2780+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
2781+
):
2782+
"""
2783+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
2784+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
2785+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
2786+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
2787+
dict is loaded into `self.transformer`.
2788+
2789+
Parameters:
2790+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
2791+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2792+
adapter_name (`str`, *optional*):
2793+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2794+
`default_{i}` where i is the total number of adapters being loaded.
2795+
low_cpu_mem_usage (`bool`, *optional*):
2796+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2797+
weights.
2798+
kwargs (`dict`, *optional*):
2799+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2800+
"""
2801+
if not USE_PEFT_BACKEND:
2802+
raise ValueError("PEFT backend is required for this method.")
2803+
2804+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
2805+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2806+
raise ValueError(
2807+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2808+
)
2809+
2810+
# if a dict is passed, copy it instead of modifying it inplace
2811+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
2812+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
2813+
2814+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
2815+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
2816+
2817+
is_correct_format = all("lora" in key for key in state_dict.keys())
2818+
if not is_correct_format:
2819+
raise ValueError("Invalid LoRA checkpoint.")
2820+
2821+
self.load_lora_into_transformer(
2822+
state_dict,
2823+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
2824+
adapter_name=adapter_name,
2825+
_pipeline=self,
2826+
low_cpu_mem_usage=low_cpu_mem_usage,
2827+
)
2828+
2829+
@classmethod
2830+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
2831+
def load_lora_into_transformer(
2832+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2833+
):
2834+
"""
2835+
This will load the LoRA layers specified in `state_dict` into `transformer`.
2836+
2837+
Parameters:
2838+
state_dict (`dict`):
2839+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
2840+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
2841+
encoder lora layers.
2842+
transformer (`CogVideoXTransformer3DModel`):
2843+
The Transformer model to load the LoRA layers into.
2844+
adapter_name (`str`, *optional*):
2845+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2846+
`default_{i}` where i is the total number of adapters being loaded.
2847+
low_cpu_mem_usage (`bool`, *optional*):
2848+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2849+
weights.
2850+
"""
2851+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2852+
raise ValueError(
2853+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2854+
)
2855+
2856+
# Load the layers corresponding to transformer.
2857+
logger.info(f"Loading {cls.transformer_name}.")
2858+
transformer.load_lora_adapter(
2859+
state_dict,
2860+
network_alphas=None,
2861+
adapter_name=adapter_name,
2862+
_pipeline=_pipeline,
2863+
low_cpu_mem_usage=low_cpu_mem_usage,
2864+
)
2865+
2866+
@classmethod
2867+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
2868+
def save_lora_weights(
2869+
cls,
2870+
save_directory: Union[str, os.PathLike],
2871+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
2872+
is_main_process: bool = True,
2873+
weight_name: str = None,
2874+
save_function: Callable = None,
2875+
safe_serialization: bool = True,
2876+
):
2877+
r"""
2878+
Save the LoRA parameters corresponding to the UNet and text encoder.
2879+
2880+
Arguments:
2881+
save_directory (`str` or `os.PathLike`):
2882+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
2883+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2884+
State dict of the LoRA layers corresponding to the `transformer`.
2885+
is_main_process (`bool`, *optional*, defaults to `True`):
2886+
Whether the process calling this is the main process or not. Useful during distributed training and you
2887+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
2888+
process to avoid race conditions.
2889+
save_function (`Callable`):
2890+
The function to use to save the state dictionary. Useful during distributed training when you need to
2891+
replace `torch.save` with another method. Can be configured with the environment variable
2892+
`DIFFUSERS_SAVE_MODE`.
2893+
safe_serialization (`bool`, *optional*, defaults to `True`):
2894+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2895+
"""
2896+
state_dict = {}
2897+
2898+
if not transformer_lora_layers:
2899+
raise ValueError("You must pass `transformer_lora_layers`.")
2900+
2901+
if transformer_lora_layers:
2902+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
2903+
2904+
# Save the model
2905+
cls.write_lora_layers(
2906+
state_dict=state_dict,
2907+
save_directory=save_directory,
2908+
is_main_process=is_main_process,
2909+
weight_name=weight_name,
2910+
save_function=save_function,
2911+
safe_serialization=safe_serialization,
2912+
)
2913+
2914+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
2915+
def fuse_lora(
2916+
self,
2917+
components: List[str] = ["transformer", "text_encoder"],
2918+
lora_scale: float = 1.0,
2919+
safe_fusing: bool = False,
2920+
adapter_names: Optional[List[str]] = None,
2921+
**kwargs,
2922+
):
2923+
r"""
2924+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
2925+
2926+
<Tip warning={true}>
2927+
2928+
This is an experimental API.
2929+
2930+
</Tip>
2931+
2932+
Args:
2933+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
2934+
lora_scale (`float`, defaults to 1.0):
2935+
Controls how much to influence the outputs with the LoRA parameters.
2936+
safe_fusing (`bool`, defaults to `False`):
2937+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
2938+
adapter_names (`List[str]`, *optional*):
2939+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
2940+
2941+
Example:
2942+
2943+
```py
2944+
from diffusers import DiffusionPipeline
2945+
import torch
2946+
2947+
pipeline = DiffusionPipeline.from_pretrained(
2948+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
2949+
).to("cuda")
2950+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
2951+
pipeline.fuse_lora(lora_scale=0.7)
2952+
```
2953+
"""
2954+
super().fuse_lora(
2955+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
2956+
)
2957+
2958+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
2959+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
2960+
r"""
2961+
Reverses the effect of
2962+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
2963+
2964+
<Tip warning={true}>
2965+
2966+
This is an experimental API.
2967+
2968+
</Tip>
2969+
2970+
Args:
2971+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
2972+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
2973+
unfuse_text_encoder (`bool`, defaults to `True`):
2974+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
2975+
LoRA parameters then it won't have any effect.
2976+
"""
2977+
super().unfuse_lora(components=components)
2978+
2979+
26722980
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
26732981
def __init__(self, *args, **kwargs):
26742982
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"SD3Transformer2DModel": lambda model_cls, weights: weights,
5353
"FluxTransformer2DModel": lambda model_cls, weights: weights,
5454
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
55+
"MochiTransformer3DModel": lambda model_cls, weights: weights,
5556
}
5657

5758

0 commit comments

Comments
 (0)