Skip to content

Commit 0b63ad5

Browse files
apolinario
and
multimodalart
authored
Create convert_diffusers_sdxl_lora_to_webui.py (#6395)
* Create convert_diffusers_sdxl_lora_to_webui.py * Move some conversion logic to utils * fix logging import * Add usage example --------- Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
1 parent 6a376ce commit 0b63ad5

File tree

2 files changed

+150
-1
lines changed

2 files changed

+150
-1
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2+
# This means that you can input your diffusers-trained LoRAs and
3+
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4+
5+
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6+
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7+
# and run the script:
8+
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9+
# now you can use corgy.safetensors in your WebUI of choice!
10+
11+
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12+
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13+
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14+
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15+
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16+
# Canonical diffusers training scripts:
17+
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18+
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19+
20+
import argparse
21+
import os
22+
23+
from safetensors.torch import load_file, save_file
24+
25+
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
26+
27+
28+
def convert_and_save(input_lora, output_lora=None):
    """Convert a diffusers-format SDXL LoRA file to Kohya/WebUI format.

    Args:
        input_lora: Path to the diffusers-trained LoRA `.safetensors` file.
        output_lora: Destination path for the converted file. When omitted,
            the input name is reused with a ``_webui`` suffix.
    """
    if output_lora is None:
        # Default: "<input-stem>_webui.safetensors" next to the input file.
        stem = os.path.splitext(input_lora)[0]
        output_lora = f"{stem}_webui.safetensors"

    # Pipeline: diffusers layout -> PEFT layout -> Kohya layout, then save.
    source_state_dict = load_file(input_lora)
    as_peft = convert_all_state_dict_to_peft(source_state_dict)
    as_kohya = convert_state_dict_to_kohya(as_peft)
    save_file(as_kohya, output_lora)
37+
38+
39+
if __name__ == "__main__":
    # CLI entry point: positional input path plus an optional output path.
    cli = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
    cli.add_argument(
        "input_lora",
        type=str,
        help="Path to the input LoRA model file in the diffusers format.",
    )
    # `nargs="?"` makes the output path optional; convert_and_save derives a
    # default name (input stem + "_webui") when it receives None.
    cli.add_argument(
        "output_lora",
        type=str,
        nargs="?",
        help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
    )
    parsed = cli.parse_args()
    convert_and_save(parsed.input_lora, parsed.output_lora)

src/diffusers/utils/state_dict_utils.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616
"""
1717
import enum
1818

19+
from .logging import get_logger
20+
21+
22+
logger = get_logger(__name__)
23+
1924

2025
class StateDictType(enum.Enum):
2126
"""
2227
The mode to use when converting state dicts.
2328
"""
2429

2530
DIFFUSERS_OLD = "diffusers_old"
26-
# KOHYA_SS = "kohya_ss" # TODO: implement this
31+
KOHYA_SS = "kohya_ss"
2732
PEFT = "peft"
2833
DIFFUSERS = "diffusers"
2934

@@ -100,6 +105,14 @@ class StateDictType(enum.Enum):
100105
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
101106
}
102107

108+
# Key-fragment renames applied when exporting PEFT LoRA weights to Kohya.
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
# adding prefixes and adding alpha values
# Check `convert_state_dict_to_kohya` for more
PEFT_TO_KOHYA_SS = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
}
115+
103116
PEFT_STATE_DICT_MAPPINGS = {
104117
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
105118
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
@@ -110,6 +123,8 @@ class StateDictType(enum.Enum):
110123
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
111124
}
112125

126+
# Source formats accepted by `convert_state_dict_to_kohya`; only PEFT is supported for now.
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
127+
113128
KEYS_TO_ALWAYS_REPLACE = {
114129
".processor.": ".",
115130
}
@@ -228,3 +243,82 @@ def convert_unet_state_dict_to_peft(state_dict):
228243
"""
229244
mapping = UNET_TO_DIFFUSERS
230245
return convert_state_dict(state_dict, mapping)
246+
247+
248+
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
    for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
    """
    try:
        converted = convert_state_dict_to_peft(state_dict)
    except Exception as err:
        # Only the "could not infer" failure gets the UNet-only fallback;
        # anything else is a real error and propagates unchanged.
        if str(err) != "Could not automatically infer state dict type":
            raise
        converted = convert_unet_state_dict_to_peft(state_dict)

    # A successful conversion must have produced PEFT-style lora_A/lora_B keys.
    has_peft_keys = any("lora_A" in key or "lora_B" in key for key in converted)
    if not has_peft_keys:
        raise ValueError("Your LoRA was not converted to PEFT")

    return converted
265+
266+
267+
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
              with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
              `get_peft_model_state_dict` method:
              https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
              but we add it here in case we don't want to rely on that method.

    Raises:
        ValueError: If `original_type` is neither supplied nor inferable as `StateDictType.PEFT`.
        ImportError: If torch is not installed (needed to build the alpha tensors).
    """
    # torch is only needed here (for `torch.tensor` below), so import lazily
    # and surface a clear message when it is missing.
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    # PEFT keys may carry a ".<adapter_name>" infix; normalize to the string
    # that must be searched for and later stripped from every key.
    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    # Infer PEFT format from the presence of lora_A weight keys.
    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping
    # (first pass: rename lora_A/lora_B -> lora_down/lora_up).
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        # Prefix each key with the Kohya module namespace. The order matters:
        # "text_encoder_2." must be checked before "text_encoder." because the
        # latter is a substring of the former.
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        # Kohya uses "_" instead of "." inside module paths; keep the last two
        # dots so the ".lora_down/.weight"-style suffix stays dot-separated.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            # Emit one alpha entry per LoRA module, keyed "<module>.alpha".
            # len(weight) is the first dimension of the lora_down matrix —
            # presumably the LoRA rank, making the effective scale alpha/rank = 1;
            # NOTE(review): confirm this matches the trainer's alpha convention.
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict

0 commit comments

Comments
 (0)