@@ -1605,9 +1605,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
+            # and 2 of them are about 1.6e-2 [the case with the AccVideo LoRA]). The low magnitudes mostly correspond
+            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
+            # is okay to ignore them because they do not affect the model output in a significant manner.
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or its values are lower than the hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
 
     # For the `diff_b` keys, we treat them as lora_bias.
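To illustrate the new pruning rule, here is a minimal sketch applying it to a toy state dict. The key names and tensor values are invented for illustration; only `torch` is assumed.

```python
import torch

# Toy .diff entries (key names are invented, not from a real checkpoint).
state_dict = {
    "blocks.0.norm.diff": torch.full((4,), 1e-5),  # tiny, near-constant spread -> dropped
    "blocks.0.attn.diff": torch.zeros(4),          # all zeros -> dropped
    "blocks.0.ffn.diff": torch.tensor([0.0, 0.5, 1.0, 2.0]),  # large spread -> kept
}

threshold = 1.6e-2
for key in [k for k in state_dict if k.endswith(".diff")]:
    param = state_dict[key]
    # Same rule as the hunk above: spread of absolute values below the
    # threshold, or an all-zero tensor, means the key is discarded.
    absdiff = param.abs().max() - param.abs().min()
    if torch.all(param == 0).item() or absdiff < threshold:
        state_dict.pop(key)

print(sorted(state_dict))  # ['blocks.0.ffn.diff']
```

Note that the rule keys off the spread of absolute values rather than the raw magnitudes, so a constant non-zero tensor is also dropped; per the comment in the hunk, that is acceptable because these entries mostly correspond to norm layers with negligible effect on the output.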
@@ -1655,12 +1664,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
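The pop-and-assign pairs replaced above all follow one guarded remap pattern, which reappears in the time-projection and image-embedder hunks below. A minimal sketch of the pattern; the helper name `maybe_remap` is ours, not part of the PR:

```python
def maybe_remap(original_state_dict, converted_state_dict, original_key, converted_key):
    # Pop-and-convert only when the key exists: LoRAs that do not train a
    # given layer are skipped silently instead of raising a KeyError.
    if original_key in original_state_dict:
        converted_state_dict[converted_key] = original_state_dict.pop(original_key)
```

The behavioral change is that the old unconditional `.pop()` raised `KeyError` for checkpoints covering only a subset of layers; the `in` guard makes the conversion tolerant of such partial LoRAs.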
@@ -1669,12 +1682,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1726,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                     original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                 )
 
+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
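Putting the pieces together, here is a toy run of the tail of the function: the new image-embedder remap followed by the leftover check. Key names and values are illustrative, and `lora_down`/`lora_up` stand in for whatever `lora_down_key`/`lora_up_key` resolve to in the real converter.

```python
lora_down_key, lora_up_key = "lora_down", "lora_up"  # placeholder values

original_state_dict = {
    "img_emb.proj.1.lora_down.weight": "A1",
    "img_emb.proj.1.lora_up.weight": "B1",
    "some_layer.norm.diff": "ignorable delta",  # never remapped
}
converted_state_dict = {}

for img_ours, img_theirs in [("ff.net.0.proj", "img_emb.proj.1"), ("ff.net.2", "img_emb.proj.3")]:
    for suffix, lora_name in [(lora_down_key, "lora_A"), (lora_up_key, "lora_B")]:
        original_key = f"{img_theirs}.{suffix}.weight"
        converted_key = f"condition_embedder.image_embedder.{img_ours}.{lora_name}.weight"
        if original_key in original_state_dict:
            converted_state_dict[converted_key] = original_state_dict.pop(original_key)

# Leftover check, as in the hunk above: if everything that remains is a
# `.diff` key, it can be ignored; otherwise the layout was unexpected.
leftover_all_diff = all(".diff" in k for k in original_state_dict)
print(sorted(converted_state_dict))  # the two remapped image_embedder keys
print(leftover_all_diff)             # True
```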