Skip to content

Commit 63e581c

Browse files
committed
add check to state_dict keys
1 parent 6637a12 commit 63e581c

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4255,22 +4255,23 @@ def _maybe_expand_t2v_lora_for_i2v(
42554255
transformer: torch.nn.Module,
42564256
state_dict,
42574257
):
4258-
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
4259-
is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
4260-
if not is_i2v_lora:
4261-
return state_dict
4262-
4263-
if transformer.config.image_dim is None:
4264-
return state_dict
4265-
4266-
for i in range(num_blocks):
4267-
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4268-
state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4269-
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
4270-
)
4271-
state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4272-
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
4273-
)
4258+
if any(k.startswith("blocks.") for k in state_dict):
4259+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
4260+
is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
4261+
if not is_i2v_lora:
4262+
return state_dict
4263+
4264+
if transformer.config.image_dim is None:
4265+
return state_dict
4266+
4267+
for i in range(num_blocks):
4268+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4269+
state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4270+
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
4271+
)
4272+
state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4273+
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
4274+
)
42744275

42754276
return state_dict
42764277

0 commit comments

Comments
 (0)