Commit 97fda1b

sayakpaul authored (co-authored by DN6 and yiyixuxu)
[LoRA] feat: support non-diffusers lumina2 LoRAs. (#10909)
* feat: support non-diffusers lumina2 LoRAs.
* revert ipynb changes (but I don't know why this is required ☹️)
* empty

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent cc22058 commit 97fda1b
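
In practice, the change lets a Lumina2 LoRA saved in the original (non-diffusers) key layout be loaded straight through the pipeline's LoRA loader. A minimal usage sketch; the pipeline class, model repo, LoRA file path, and adapter name below are illustrative assumptions, not taken from this commit:

import torch
from diffusers import Lumina2Pipeline  # assumed Lumina2 text-to-image pipeline class

pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16  # placeholder model repo
).to("cuda")

# Placeholder LoRA file: any Lumina2 LoRA whose keys start with "diffusion_model."
# (the original, non-diffusers layout) is now converted automatically on load.
pipe.load_lora_weights("path/to/lumina2_lora.safetensors", adapter_name="custom-style")

image = pipe(prompt="a watercolor fox in a snowy forest").images[0]
image.save("lumina2_lora_sample.png")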

File tree

2 files changed: +77 −1 lines changed

src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/loaders/lora_pipeline.py

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 71 additions & 0 deletions
@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
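
For illustration, the converter can be exercised on its own. The sketch below builds a fake single-block state dict in the original layout (made-up rank and tensor shapes; the fused qkv projection stacks 2304 query rows and 768 key/value rows each, matching the split sizes above) and prints the remapped diffusers keys. Note that _convert_non_diffusers_lumina2_lora_to_diffusers is a private helper, so its import path may change:

import torch
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_lumina2_lora_to_diffusers

rank, hidden = 4, 2304  # illustrative rank; 2304 is the hidden size implied by the qkv split above

def lora_pair(in_features, out_features):
    # LoRA A maps in_features -> rank, LoRA B maps rank -> out_features.
    return torch.randn(rank, in_features), torch.randn(out_features, rank)

prefix = "diffusion_model.layers.0"
sd = {}
sd[f"{prefix}.attention.qkv.lora_A.weight"], sd[f"{prefix}.attention.qkv.lora_B.weight"] = lora_pair(hidden, 2304 + 768 + 768)
sd[f"{prefix}.attention.out.lora_A.weight"], sd[f"{prefix}.attention.out.lora_B.weight"] = lora_pair(hidden, hidden)
for i in (1, 2, 3):
    sd[f"{prefix}.feed_forward.w{i}.lora_A.weight"], sd[f"{prefix}.feed_forward.w{i}.lora_B.weight"] = lora_pair(hidden, hidden)
sd[f"{prefix}.adaLN_modulation.1.lora_A.weight"], sd[f"{prefix}.adaLN_modulation.1.lora_B.weight"] = lora_pair(hidden, 4 * hidden)

converted = _convert_non_diffusers_lumina2_lora_to_diffusers(sd)
print(sorted(converted))
# The fused qkv LoRA becomes per-projection entries such as
# "transformer.layers.0.attn.to_q.lora_B.weight" (2304 rows) and
# "transformer.layers.0.attn.to_k.lora_B.weight" / "...to_v..." (768 rows each),
# while "adaLN_modulation.1" maps to "norm1.linear".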

src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -41,6 +41,7 @@
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3909,6 +3909,11 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        # conversion.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
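
The dispatch added above is a plain key-prefix check, so state dicts already in the diffusers layout pass through untouched. A tiny sketch with placeholder key names (not from a real checkpoint):

# Illustrative key names only; real checkpoints contain many such entries.
original_layout = {"diffusion_model.layers.0.attention.qkv.lora_A.weight": None}
diffusers_layout = {"transformer.layers.0.attn.to_q.lora_A.weight": None}

for sd in (original_layout, diffusers_layout):
    non_diffusers = any(k.startswith("diffusion_model.") for k in sd)
    print(non_diffusers)  # True -> converted by _convert_non_diffusers_lumina2_lora_to_diffusers; False -> left as-is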
