Skip to content

Commit b5689c4

Browse files
committed
Merge remote-tracking branch 'origin/video-loras' into video-loras
2 parents ccdc4fd + 5e6a15b commit b5689c4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4251,9 +4251,9 @@ def lora_state_dict(
42514251

42524252
@classmethod
42534253
def maybe_expand_t2v_lora_for_i2v(
4254-
cls,
4255-
transformer: torch.nn.Module,
4256-
state_dict,
4254+
cls,
4255+
transformer: torch.nn.Module,
4256+
state_dict,
42574257
):
42584258
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
42594259
is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
@@ -4313,9 +4313,9 @@ def load_lora_weights(
43134313
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
43144314
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
43154315
state_dict = self._maybe_expand_t2v_lora_for_i2v(
4316-
transformer = getattr(self, self.transformer_name) if not hasattr(self,
4317-
"transformer") else self.transformer,
4318-
state_dict = state_dict)
4316+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4317+
state_dict=state_dict,
4318+
)
43194319
is_correct_format = all("lora" in key for key in state_dict.keys())
43204320
if not is_correct_format:
43214321
raise ValueError("Invalid LoRA checkpoint.")

0 commit comments

Comments (0)