Skip to content

Commit 1622265

Browse files
apolinariomultimodalart
and
multimodalart
authored
Add WebUI format support to Advanced Training Script (#6403)
* Add WebUI format support to Advanced Training Script * style --------- Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
1 parent 0b63ad5 commit 1622265

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import math
2222
import os
23+
import re
2324
import shutil
2425
import warnings
2526
from pathlib import Path
@@ -41,7 +42,7 @@
4142
from peft.utils import get_peft_model_state_dict
4243
from PIL import Image
4344
from PIL.ImageOps import exif_transpose
44-
from safetensors.torch import save_file
45+
from safetensors.torch import load_file, save_file
4546
from torch.utils.data import Dataset
4647
from torchvision import transforms
4748
from tqdm.auto import tqdm
@@ -58,7 +59,13 @@
5859
from diffusers.loaders import LoraLoaderMixin
5960
from diffusers.optimization import get_scheduler
6061
from diffusers.training_utils import compute_snr
61-
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
62+
from diffusers.utils import (
63+
check_min_version,
64+
convert_all_state_dict_to_peft,
65+
convert_state_dict_to_diffusers,
66+
convert_state_dict_to_kohya,
67+
is_wandb_available,
68+
)
6269
from diffusers.utils.import_utils import is_xformers_available
6370

6471

@@ -93,10 +100,17 @@ def save_model_card(
93100
img_str += f"""
94101
- text: '{instance_prompt}'
95102
"""
96-
103+
embeddings_filename = f"{repo_folder}_emb"
104+
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
105+
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
106+
if instance_prompt_webui != embeddings_filename:
107+
instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
108+
else:
109+
instance_prompt_sentence = ""
97110
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
98111
diffusers_imports_pivotal = ""
99112
diffusers_example_pivotal = ""
113+
webui_example_pivotal = ""
100114
if train_text_encoder_ti:
101115
trigger_str = (
102116
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
@@ -105,11 +119,16 @@ def save_model_card(
105119
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
106120
from safetensors.torch import load_file
107121
"""
108-
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
122+
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
109123
state_dict = load_file(embedding_path)
110-
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
111-
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
124+
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
125+
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
112126
"""
127+
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
128+
- Place it on it on your `embeddings` folder
129+
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
130+
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
131+
"""
113132
if token_abstraction_dict:
114133
for key, value in token_abstraction_dict.items():
115134
tokens = "".join(value)
@@ -141,9 +160,14 @@ def save_model_card(
141160
142161
### These are {repo_id} LoRA adaption weights for {base_model}.
143162
144-
## Trigger words
163+
## Download model
145164
146-
{trigger_str}
165+
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
166+
167+
- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
168+
- Place it on your `models/Lora` folder.
169+
- On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
170+
{webui_example_pivotal}
147171
148172
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
149173
@@ -159,16 +183,12 @@ def save_model_card(
159183
160184
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
161185
162-
## Download model
163-
164-
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
165-
166-
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
167-
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
186+
## Trigger words
168187
169-
All [Files & versions](/{repo_id}/tree/main).
188+
{trigger_str}
170189
171190
## Details
191+
All [Files & versions](/{repo_id}/tree/main).
172192
173193
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
174194
@@ -2035,8 +2055,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
20352055

20362056
if args.train_text_encoder_ti:
20372057
embedding_handler.save_embeddings(
2038-
f"{args.output_dir}/embeddings.safetensors",
2058+
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
20392059
)
2060+
2061+
# Conver to WebUI format
2062+
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
2063+
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
2064+
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
2065+
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
2066+
20402067
save_model_card(
20412068
model_id if not args.push_to_hub else repo_id,
20422069
images=images,

0 commit comments

Comments
 (0)