Skip to content

Commit aa1797e

Browse files
enable stable-xl textual inversion (#6421)
* enable stable-xl textual inversion
* check if optimizer_2 exists
* check text_encoder_2 before using
* add textual inversion for sdxl in a single file
* fix style
* fix example style
* reset for error changes
* add readme for sdxl
* fix style
* disable autocast as it will cause cast error when weight_dtype=bf16
* fix spelling error
* fix style and readme and 8bit optimizer
* add README_sdxl.md link
* add tracker key on log_validation
* run style
* rm the second center crop

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 5bacc2f commit aa1797e

File tree

4 files changed

+1238
-0
lines changed

4 files changed

+1238
-0
lines changed

examples/textual_inversion/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ Now we can launch the training using:
6060

6161
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
6262

63+
**___Note: Follow the instructions in [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model.___**
64+
6365
```bash
6466
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
6567
export DATA_DIR="./cat"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
## Textual Inversion fine-tuning example for SDXL
2+
3+
```bash
4+
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
5+
export DATA_DIR="./cat"
6+
7+
accelerate launch textual_inversion_sdxl.py \
8+
--pretrained_model_name_or_path=$MODEL_NAME \
9+
--train_data_dir=$DATA_DIR \
10+
--learnable_property="object" \
11+
--placeholder_token="<cat-toy>" \
12+
--initializer_token="toy" \
13+
--mixed_precision="bf16" \
14+
--resolution=768 \
15+
--train_batch_size=1 \
16+
--gradient_accumulation_steps=4 \
17+
--max_train_steps=500 \
18+
--learning_rate=5.0e-04 \
19+
--scale_lr \
20+
--lr_scheduler="constant" \
21+
--lr_warmup_steps=0 \
22+
--save_as_full_pipeline \
23+
--output_dir="./textual_inversion_cat_sdxl"
24+
```
25+
26+
For now, only training of the first text encoder is supported.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# coding=utf-8
2+
# Copyright 2023 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
22+
sys.path.append("..")
23+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
24+
25+
26+
# Configure the root logger at DEBUG level for the example test run.
logging.basicConfig(level=logging.DEBUG)

# Also attach a handler to the root logger that writes records to stdout.
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
31+
32+
33+
class TextualInversionSdxl(ExamplesTestsAccelerate):
    """Smoke tests for ``examples/textual_inversion/textual_inversion_sdxl.py``.

    Each test launches the training script through ``run_command`` using the
    ``hf-internal-testing/tiny-sdxl-pipe`` model and then inspects the files
    left behind in a temporary output directory.
    """

    @staticmethod
    def _script_args(output_dir, max_train_steps, *extra_args):
        """Build the script path and CLI arguments shared by every test here.

        Args:
            output_dir: Directory passed as the value of ``--output_dir``.
            max_train_steps: Value passed as ``--max_train_steps``.
            *extra_args: Extra arguments appended verbatim after ``--output_dir``.

        Returns:
            A list of strings: the script path followed by its arguments.
        """
        args = f"""
            examples/textual_inversion/textual_inversion_sdxl.py
            --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
            --train_data_dir docs/source/en/imgs
            --learnable_property object
            --placeholder_token <cat-toy>
            --initializer_token a
            --save_steps 1
            --num_vectors 2
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps {max_train_steps}
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --lr_warmup_steps 0
            """.split()
        # Pass the directory as a single argv element: interpolating it into
        # the string above would split a path that contains whitespace.
        args += ["--output_dir", output_dir]
        args.extend(extra_args)
        return args

    @staticmethod
    def _checkpoint_entries(output_dir):
        """Return the set of entries in ``output_dir`` whose name contains "checkpoint"."""
        return {x for x in os.listdir(output_dir) if "checkpoint" in x}

    def test_textual_inversion_sdxl(self):
        """Run two training steps and check that the learned embeddings file exists."""
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = self._script_args(tmpdir, 2)

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))

    def test_textual_inversion_sdxl_checkpointing(self):
        """Run three steps with ``--checkpoints_total_limit=2`` and expect the last two checkpoints."""
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = self._script_args(
                tmpdir,
                3,
                "--checkpointing_steps=1",
                "--checkpoints_total_limit=2",
            )

            run_command(self._launch_args + test_args)

            # check checkpoint directories exist
            self.assertEqual(
                self._checkpoint_entries(tmpdir),
                {"checkpoint-2", "checkpoint-3"},
            )

    def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        """Resume from ``checkpoint-2`` with a limit of two and expect checkpoints 2 and 3 to remain."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # First run: no total limit, so both checkpoints are expected.
            test_args = self._script_args(tmpdir, 2, "--checkpointing_steps=1")

            run_command(self._launch_args + test_args)

            # check checkpoint directories exist
            self.assertEqual(
                self._checkpoint_entries(tmpdir),
                {"checkpoint-1", "checkpoint-2"},
            )

            # Second run: resume from the latest checkpoint, now with a limit.
            resume_run_args = self._script_args(
                tmpdir,
                2,
                "--checkpointing_steps=1",
                "--resume_from_checkpoint=checkpoint-2",
                "--checkpoints_total_limit=2",
            )

            run_command(self._launch_args + resume_run_args)

            # check checkpoint directories exist
            self.assertEqual(
                self._checkpoint_entries(tmpdir),
                {"checkpoint-2", "checkpoint-3"},
            )

0 commit comments

Comments
 (0)