From f351017fae1f3a5f2a5803316747673b235e5df8 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 15:17:02 +0200
Subject: [PATCH 01/16] @hlky t2v->i2v

---
 src/diffusers/loaders/lora_pipeline.py | 30 +++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 160793ba1b58..4f540e0fcfa4 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4249,6 +4249,31 @@ def lora_state_dict(
 
         return state_dict
 
+    @classmethod
+    def maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+        is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
+        if not is_i2v_lora:
+            return state_dict
+
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        for i in range(num_blocks):
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
+                )
+                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
+                )
+
+        return state_dict
+
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -4287,7 +4312,10 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer = getattr(self, self.transformer_name) if not hasattr(self,
+                                                                              "transformer") else self.transformer,
+            state_dict = state_dict)
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
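Patch 01 introduces the expansion hook: a text-to-video (T2V) Wan LoRA has no entries for the image cross-attention projections (`add_k_proj`/`add_v_proj`) that the image-to-video (I2V) transformer adds, so zero-filled `lora_A`/`lora_B` pairs are padded in before loading. Zeros are safe because a LoRA's weight delta is the product of the two factors. A minimal sketch of that property (illustrative sizes, not the library code):

```python
# Why zero-filled pairs are a no-op: the LoRA update is scale * (B @ A),
# so all-zero factors contribute exactly nothing to the padded layer.
import torch

rank, dim = 4, 16                        # placeholder sizes, not Wan's real dims
lora_A = torch.zeros(rank, dim)          # zero-filled entry for the missing _img layer
lora_B = torch.zeros(dim, rank)

delta = lora_B @ lora_A                  # effective weight update of the adapter
assert torch.count_nonzero(delta) == 0   # image branch behaves as if no LoRA is loaded
```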
From 5e6a15b0c3909f49ab3b50ca61f5a9efe120f7f6 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Tue, 18 Mar 2025 13:25:12 +0000
Subject: [PATCH 02/16] Apply style fixes

---
 src/diffusers/loaders/lora_pipeline.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4f540e0fcfa4..70df0fe01073 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4251,9 +4251,9 @@ def lora_state_dict(
 
     @classmethod
     def maybe_expand_t2v_lora_for_i2v(
-        cls, 
-        transformer: torch.nn.Module, 
-        state_dict, 
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
     ):
         num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
         is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
         if not is_i2v_lora:
             return state_dict
@@ -4313,9 +4313,9 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         state_dict = self._maybe_expand_t2v_lora_for_i2v(
-            transformer = getattr(self, self.transformer_name) if not hasattr(self,
-                                                                              "transformer") else self.transformer,
-            state_dict = state_dict)
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

From ccdc4fd6a9406e94e5acf41ae80706c1b936995e Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 15:37:44 +0200
Subject: [PATCH 03/16] try with ones to not nullify layers

---
 src/diffusers/loaders/lora_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4f540e0fcfa4..01bc773609cc 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4265,10 +4265,10 @@ def maybe_expand_t2v_lora_for_i2v(
 
         for i in range(num_blocks):
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.ones_like(
                     state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
                 )
-                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.ones_like(
                     state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
                 )

From fe2d3b4c7d0a30a72825a064202eb420ea367876 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 16:01:49 +0200
Subject: [PATCH 04/16] fix method name

---
 src/diffusers/loaders/lora_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index d896ea862c1f..292fbeeece8a 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4250,7 +4250,7 @@ def lora_state_dict(
         return state_dict
 
     @classmethod
-    def maybe_expand_t2v_lora_for_i2v(
+    def _maybe_expand_t2v_lora_for_i2v(
         cls,
         transformer: torch.nn.Module,
         state_dict,

From 6637a121a0516b605842fb6b921eec641fcd37ac Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 16:34:51 +0200
Subject: [PATCH 05/16] revert to zeros

---
 src/diffusers/loaders/lora_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 292fbeeece8a..03d898fda1dc 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4265,10 +4265,10 @@ def _maybe_expand_t2v_lora_for_i2v(
 
         for i in range(num_blocks):
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.ones_like(
+                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
                     state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
                 )
-                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.ones_like(
+                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
                     state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
                 )
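Patches 03 and 05 bracket a quick experiment: `torch.ones_like` was tried so the padded layers would not be "nullified", then reverted, because all-ones factors genuinely change the layer output. A sketch of the difference, with made-up sizes:

```python
# Why the ones_like experiment was reverted: with rank-r all-ones factors every
# entry of the delta equals r, a real perturbation of add_k_proj/add_v_proj,
# while all-zero factors leave those projections untouched.
import torch

rank, dim = 4, 8  # illustrative sizes only
ones_delta = torch.ones(dim, rank) @ torch.ones(rank, dim)
zeros_delta = torch.zeros(dim, rank) @ torch.zeros(rank, dim)

print(ones_delta[0, 0].item())   # 4.0 -> alters the layer output
print(zeros_delta[0, 0].item())  # 0.0 -> identity behaviour, the intended no-op
```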
From 63e581cad8b4faa5e7108a81026d5b39225e0b39 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 17:33:14 +0200
Subject: [PATCH 06/16] add check to state_dict keys

---
 src/diffusers/loaders/lora_pipeline.py | 33 +++++++++++++-------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 03d898fda1dc..f4ce851ca859 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4255,22 +4255,23 @@ def _maybe_expand_t2v_lora_for_i2v(
         transformer: torch.nn.Module,
         state_dict,
     ):
-        num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
-        is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
-        if not is_i2v_lora:
-            return state_dict
-
-        if transformer.config.image_dim is None:
-            return state_dict
-
-        for i in range(num_blocks):
-            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
-                )
-                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
-                )
+        if any(k.startswith("blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+            is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
+            if not is_i2v_lora:
+                return state_dict
+
+            if transformer.config.image_dim is None:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
+                    )
+                    state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
+                    )
 
         return state_dict

From 9fa3d933b3ff2c0180b5ee384db1472e9c8659b3 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 17:37:45 +0200
Subject: [PATCH 07/16] add comment

---
 src/diffusers/loaders/lora_pipeline.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f4ce851ca859..2868fb7d1370 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4275,7 +4275,7 @@ def _maybe_expand_t2v_lora_for_i2v(
 
         return state_dict
 
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights with T2V LoRA->I2V LoRA option
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
@@ -4313,6 +4313,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
         state_dict = self._maybe_expand_t2v_lora_for_i2v(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,
         )
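The guard added in patch 06 matters because the block-counting line indexes the result of `str.split`: for any key that does not contain `"blocks."`, the `[1]` lookup raises `IndexError`. A small repro (the key name below is hypothetical):

```python
# Repro of the failure mode the patch-06 guard prevents (hypothetical key name).
keys = ["condition_embedder.lora_A.weight"]

try:
    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in keys})
except IndexError:
    print("keys without the 'blocks.' prefix break the unguarded counting")

if any(k.startswith("blocks.") for k in keys):  # the guard from patch 06
    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in keys})
```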
From 051f534d185c0ea065bf36a9926c4b48f496d429 Mon Sep 17 00:00:00 2001
From: Linoy
Date: Tue, 18 Mar 2025 15:50:27 +0000
Subject: [PATCH 08/16] copies fix

---
 src/diffusers/loaders/lora_pipeline.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2868fb7d1370..e9fcca6ab7fa 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4313,11 +4313,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
-        state_dict = self._maybe_expand_t2v_lora_for_i2v(
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            state_dict=state_dict,
-        )
+
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

From 3834c16c63f03c7d49eec472e530e10ca5aa8dee Mon Sep 17 00:00:00 2001
From: Linoy
Date: Tue, 18 Mar 2025 15:52:28 +0000
Subject: [PATCH 09/16] Revert "copies fix"

This reverts commit 051f534d185c0ea065bf36a9926c4b48f496d429.

---
 src/diffusers/loaders/lora_pipeline.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e9fcca6ab7fa..2868fb7d1370 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4313,7 +4313,11 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
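Patches 08–10 revolve around diffusers' `# Copied from ...` convention: the repo's copy checker regenerates any annotated function from its named source, so the Wan-specific expansion call kept being stripped (patch 08, reverted in patch 09) until the annotation itself was dropped (patch 10, below). Conceptually the constraint is something like this illustrative sketch, not the actual checker implementation:

```python
# Illustrative sketch of the copy-consistency rule behind patches 08-10:
# a body annotated "# Copied from X" must match X exactly, so a method that
# needs extra Wan-only lines cannot keep the annotation.
def copy_is_consistent(annotated_body: str, source_body: str) -> bool:
    return annotated_body.strip() == source_body.strip()
```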
From 292d6188c58956c01ac2dfda06f5cc5502a5b1fb Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 17:54:54 +0200
Subject: [PATCH 10/16] remove copied from

---
 src/diffusers/loaders/lora_pipeline.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2868fb7d1370..18a38a071e3a 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4275,7 +4275,6 @@ def _maybe_expand_t2v_lora_for_i2v(
 
         return state_dict
 
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights with T2V LoRA->I2V LoRA option
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):

From 92aabcb9f0bf972e97530bc857ed8c95ab38f402 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:27:21 +0200
Subject: [PATCH 11/16] Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky

---
 src/diffusers/loaders/lora_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 18a38a071e3a..42530e7e59aa 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4258,7 +4258,7 @@ def _maybe_expand_t2v_lora_for_i2v(
         if any(k.startswith("blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
             is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
-            if not is_i2v_lora:
+            if is_i2v_lora:
                 return state_dict
 
             if transformer.config.image_dim is None:
                 return state_dict

From 51c570dcf3c19d6f63f6c3d704e4027dd554a988 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:35:30 +0200
Subject: [PATCH 12/16] Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky

---
 src/diffusers/loaders/lora_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 42530e7e59aa..ce9eaa64806b 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4257,7 +4257,7 @@ def _maybe_expand_t2v_lora_for_i2v(
     ):
         if any(k.startswith("blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
-            is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
             if is_i2v_lora:
                 return state_dict
 
             if transformer.config.image_dim is None:
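After patches 11 and 12 the early return is driven by what the checkpoint already contains: if `add_k_proj`/`add_v_proj` entries are present it is already an I2V LoRA and is returned untouched; only checkpoints without them get padded. A sketch of that predicate (helper name and keys are illustrative):

```python
# Sketch of the post-patch-12 detection: expansion is only needed when the
# image-attention projections are absent from the checkpoint keys.
def needs_expansion(keys) -> bool:
    is_i2v_lora = any("add_k_proj" in k for k in keys) and any("add_v_proj" in k for k in keys)
    return not is_i2v_lora

print(needs_expansion(["blocks.0.attn2.to_k.lora_A.weight"]))        # True  (T2V LoRA)
print(needs_expansion(["blocks.0.attn2.add_k_proj.lora_A.weight",
                       "blocks.0.attn2.add_v_proj.lora_A.weight"]))  # False (I2V LoRA)
```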
From f5b5986e22564a0c97c435d5bfcb484ed8d77748 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 15:50:41 +0200
Subject: [PATCH 13/16] update

---
 src/diffusers/loaders/lora_pipeline.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index ce9eaa64806b..cd9303429de7 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4255,13 +4255,15 @@ def _maybe_expand_t2v_lora_for_i2v(
         transformer: torch.nn.Module,
         state_dict,
     ):
-        if any(k.startswith("blocks.") for k in state_dict):
+
+        if transformer.config.image_dim is not None:
+            return state_dict
+
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-            if is_i2v_lora:
-                return state_dict
 
-            if transformer.config.image_dim is None:
+            if is_i2v_lora:
                 return state_dict
 
             for i in range(num_blocks):

From d2dd6ae161986d6d3a2e124b2747a80cd773485d Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 15:57:15 +0200
Subject: [PATCH 14/16] update

---
 src/diffusers/loaders/lora_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index cd9303429de7..166a6c5c18fb 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4256,7 +4256,7 @@ def _maybe_expand_t2v_lora_for_i2v(
         state_dict,
     ):
 
-        if transformer.config.image_dim is not None:
+        if transformer.config.image_dim is None:
             return state_dict
 
         if any(k.startswith("transformer.blocks.") for k in state_dict):

From c46445590cfc32016e4d5c2ce42e435cb759de06 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:57:32 +0200
Subject: [PATCH 15/16] Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky

---
 src/diffusers/loaders/lora_pipeline.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index cd9303429de7..08ae77dabb0e 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4268,11 +4268,11 @@ def _maybe_expand_t2v_lora_for_i2v(
 
             for i in range(num_blocks):
                 for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
                     )
-                    state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
                     )
 
         return state_dict
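Patches 13–15 settle the final shape of the helper: the `image_dim is None` check moves to the top so a plain T2V transformer returns immediately, the key prefix becomes `transformer.blocks.` (the converted diffusers state-dict format), and the zero templates are taken from the text branch's `to_k` weights. The per-block key layout it reads and writes looks roughly like this (rank and dims are placeholders):

```python
# Sketch of the key layout the final version produces for block i.
import torch

i, rank, dim = 0, 4, 16  # placeholder block index and sizes
state_dict = {f"transformer.blocks.{i}.attn2.to_k.lora_A.weight": torch.randn(rank, dim)}

for c in ("add_k_proj", "add_v_proj"):  # padded image-attention projections
    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
    )

print(sorted(state_dict))  # trained to_k entry plus two zero-filled entries
```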
From 6c394656dd3b27b7df09da263242ad14be693097 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 19 Mar 2025 14:22:34 +0000
Subject: [PATCH 16/16] Apply style fixes

---
 src/diffusers/loaders/lora_pipeline.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e81356202438..e522778deeed 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4255,7 +4255,6 @@ def _maybe_expand_t2v_lora_for_i2v(
         transformer: torch.nn.Module,
         state_dict,
     ):
-
         if transformer.config.image_dim is None:
             return state_dict
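Taken together, the series lets a Wan T2V LoRA load transparently into the I2V pipeline, with the missing image-attention entries zero-filled on the fly. A hedged usage sketch; the LoRA repo id below is a placeholder, not a real checkpoint:

```python
# Usage sketch: loading a T2V-trained LoRA into Wan I2V. The missing
# add_k_proj/add_v_proj entries are padded by _maybe_expand_t2v_lora_for_i2v.
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("some-user/wan-t2v-style-lora")  # placeholder repo id
```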