diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index b16bf92a6370..a35b73cb8a2e 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -14,6 +14,10 @@ # Wan +
+ LoRA +
+ [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team. diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 4be6971755d2..2f022098b368 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1348,3 +1348,53 @@ def process_block(prefix, index, convert_norm): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict + + +def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): + converted_state_dict = {} + original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} + + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) + + for i in range(num_blocks): + # Self-attention + for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_B.weight" + ) + + # Cross-attention + for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) + for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) + + # FFN + for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.{o}.lora_B.weight" + ) + + if len(original_state_dict) > 0: + raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b0743d5a6ed5..1dce86e2fd71 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -42,6 +42,7 @@ _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, + _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) @@ -4111,7 +4112,6 @@ class WanLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -4198,6 +4198,8 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: