Skip to content

Commit 18b7d6d

Browse files
authored
Merge branch 'main' into pyramid-attention-broadcast
2 parents a5f51bb + a3cc641 commit 18b7d6d

File tree

15 files changed

+1077
-535
lines changed

15 files changed

+1077
-535
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
convert_state_dict_to_kohya,
6868
is_wandb_available,
6969
)
70+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
7071
from diffusers.utils.import_utils import is_xformers_available
7172

7273

@@ -79,30 +80,27 @@
7980
def save_model_card(
8081
repo_id: str,
8182
use_dora: bool,
82-
images=None,
83-
base_model=str,
83+
images: list = None,
84+
base_model: str = None,
8485
train_text_encoder=False,
8586
train_text_encoder_ti=False,
8687
token_abstraction_dict=None,
87-
instance_prompt=str,
88-
validation_prompt=str,
88+
instance_prompt=None,
89+
validation_prompt=None,
8990
repo_folder=None,
9091
vae_path=None,
9192
):
92-
img_str = "widget:\n"
9393
lora = "lora" if not use_dora else "dora"
94-
for i, image in enumerate(images):
95-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
96-
img_str += f"""
97-
- text: '{validation_prompt if validation_prompt else ' ' }'
98-
output:
99-
url:
100-
"image_{i}.png"
101-
"""
102-
if not images:
103-
img_str += f"""
104-
- text: '{instance_prompt}'
105-
"""
94+
95+
widget_dict = []
96+
if images is not None:
97+
for i, image in enumerate(images):
98+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
99+
widget_dict.append(
100+
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
101+
)
102+
else:
103+
widget_dict.append({"text": instance_prompt})
106104
embeddings_filename = f"{repo_folder}_emb"
107105
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
108106
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -137,24 +135,7 @@ def save_model_card(
137135
trigger_str += f"""
138136
to trigger concept `{key}` → use `{tokens}` in your prompt \n
139137
"""
140-
141-
yaml = f"""---
142-
tags:
143-
- stable-diffusion
144-
- stable-diffusion-diffusers
145-
- diffusers-training
146-
- text-to-image
147-
- diffusers
148-
- {lora}
149-
- template:sd-lora
150-
{img_str}
151-
base_model: {base_model}
152-
instance_prompt: {instance_prompt}
153-
license: openrail++
154-
---
155-
"""
156-
157-
model_card = f"""
138+
model_description = f"""
158139
# SD1.5 LoRA DreamBooth - {repo_id}
159140
160141
<Gallery />
@@ -202,8 +183,28 @@ def save_model_card(
202183
Special VAE used for training: {vae_path}.
203184
204185
"""
205-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
206-
f.write(yaml + model_card)
186+
model_card = load_or_create_model_card(
187+
repo_id_or_path=repo_id,
188+
from_training=True,
189+
license="openrail++",
190+
base_model=base_model,
191+
prompt=instance_prompt,
192+
model_description=model_description,
193+
inference=True,
194+
widget=widget_dict,
195+
)
196+
197+
tags = [
198+
"text-to-image",
199+
"diffusers",
200+
"diffusers-training",
201+
lora,
202+
"template:sd-lora" "stable-diffusion",
203+
"stable-diffusion-diffusers",
204+
]
205+
model_card = populate_model_card(model_card, tags=tags)
206+
207+
model_card.save(os.path.join(repo_folder, "README.md"))
207208

208209

209210
def import_model_class_from_model_name_or_path(

examples/controlnet/train_controlnet_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def log_validation(
152152
guidance_scale=3.5,
153153
generator=generator,
154154
).images[0]
155+
image = image.resize((args.resolution, args.resolution))
155156
images.append(image)
156157
image_logs.append(
157158
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
is_wandb_available,
5858
)
5959
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
60+
from diffusers.utils.import_utils import is_torch_npu_available
6061
from diffusers.utils.torch_utils import is_compiled_module
6162

6263

@@ -68,6 +69,12 @@
6869

6970
logger = get_logger(__name__)
7071

72+
if is_torch_npu_available():
73+
import torch_npu
74+
75+
torch.npu.config.allow_internal_format = False
76+
torch.npu.set_compile_mode(jit_compile=False)
77+
7178

7279
def save_model_card(
7380
repo_id: str,
@@ -189,6 +196,8 @@ def log_validation(
189196
del pipeline
190197
if torch.cuda.is_available():
191198
torch.cuda.empty_cache()
199+
elif is_torch_npu_available():
200+
torch_npu.npu.empty_cache()
192201

193202
return images
194203

@@ -1035,7 +1044,9 @@ def main(args):
10351044
cur_class_images = len(list(class_images_dir.iterdir()))
10361045

10371046
if cur_class_images < args.num_class_images:
1038-
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
1047+
has_supported_fp16_accelerator = (
1048+
torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
1049+
)
10391050
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
10401051
if args.prior_generation_precision == "fp32":
10411052
torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
10731084
del pipeline
10741085
if torch.cuda.is_available():
10751086
torch.cuda.empty_cache()
1087+
elif is_torch_npu_available():
1088+
torch_npu.npu.empty_cache()
10761089

10771090
# Handle the repository creation
10781091
if accelerator.is_main_process:
@@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
13541367
gc.collect()
13551368
if torch.cuda.is_available():
13561369
torch.cuda.empty_cache()
1370+
elif is_torch_npu_available():
1371+
torch_npu.npu.empty_cache()
13571372

13581373
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
13591374
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1719,9 +1734,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17191734
)
17201735
if not args.train_text_encoder:
17211736
del text_encoder_one, text_encoder_two
1722-
torch.cuda.empty_cache()
1737+
if torch.cuda.is_available():
1738+
torch.cuda.empty_cache()
1739+
elif is_torch_npu_available():
1740+
torch_npu.npu.empty_cache()
17231741
gc.collect()
17241742

1743+
images = None
1744+
del pipeline
1745+
17251746
# Save the lora layers
17261747
accelerator.wait_for_everyone()
17271748
if accelerator.is_main_process:
@@ -1780,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17801801
ignore_patterns=["step_*", "epoch_*"],
17811802
)
17821803

1804+
images = None
1805+
del pipeline
1806+
17831807
accelerator.end_training()
17841808

17851809

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def log_validation(
177177
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
178178
f" {args.validation_prompt}."
179179
)
180-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
180+
pipeline = pipeline.to(accelerator.device)
181181
pipeline.set_progress_bar_config(disable=True)
182182

183183
# run inference
@@ -1706,7 +1706,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17061706
)
17071707

17081708
# handle guidance
1709-
if transformer.config.guidance_embeds:
1709+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
17101710
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
17111711
guidance = guidance.expand(model_input.shape[0])
17121712
else:
@@ -1819,6 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18191819
# create pipeline
18201820
if not args.train_text_encoder:
18211821
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
1822+
text_encoder_one.to(weight_dtype)
1823+
text_encoder_two.to(weight_dtype)
18221824
pipeline = FluxPipeline.from_pretrained(
18231825
args.pretrained_model_name_or_path,
18241826
vae=vae,
@@ -1842,6 +1844,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18421844
del text_encoder_one, text_encoder_two
18431845
free_memory()
18441846

1847+
images = None
1848+
del pipeline
1849+
18451850
# Save the lora layers
18461851
accelerator.wait_for_everyone()
18471852
if accelerator.is_main_process:
@@ -1906,6 +1911,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
19061911
ignore_patterns=["step_*", "epoch_*"],
19071912
)
19081913

1914+
images = None
1915+
del pipeline
1916+
19091917
accelerator.end_training()
19101918

19111919

examples/reinforcement_learning/README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
# Overview
1+
2+
## Diffusion-based Policy Learning for RL
3+
4+
`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.
5+
6+
This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.
7+
8+
To execute the script, run `diffusion_policy.py`
9+
10+
## Diffuser Locomotion
211

312
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
413
There are two ways to use the script, `run_diffuser_locomotion.py`.

0 commit comments

Comments
 (0)