@@ -4255,22 +4255,23 @@ def _maybe_expand_t2v_lora_for_i2v(
4255
4255
transformer : torch .nn .Module ,
4256
4256
state_dict ,
4257
4257
):
4258
- num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict })
4259
- is_i2v_lora = any ("k_img" in k for k in state_dict ) and any ("v_img" in k for k in state_dict )
4260
- if not is_i2v_lora :
4261
- return state_dict
4262
-
4263
- if transformer .config .image_dim is None :
4264
- return state_dict
4265
-
4266
- for i in range (num_blocks ):
4267
- for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
4268
- state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4269
- state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_A.weight" ]
4270
- )
4271
- state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4272
- state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_B.weight" ]
4273
- )
4258
+ if any (k .startswith ("blocks." ) for k in state_dict ):
4259
+ num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict })
4260
+ is_i2v_lora = any ("k_img" in k for k in state_dict ) and any ("v_img" in k for k in state_dict )
4261
+ if not is_i2v_lora :
4262
+ return state_dict
4263
+
4264
+ if transformer .config .image_dim is None :
4265
+ return state_dict
4266
+
4267
+ for i in range (num_blocks ):
4268
+ for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
4269
+ state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4270
+ state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_A.weight" ]
4271
+ )
4272
+ state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4273
+ state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_B.weight" ]
4274
+ )
4274
4275
4275
4276
return state_dict
4276
4277
0 commit comments