Commit bc74fe8

Merge branch 'main' into lora-load-adapter
2 parents: 37acd79 + 1d1e1a2

9 files changed: +99 −19 lines

.github/workflows/push_tests.yml

Lines changed: 3 additions & 3 deletions
@@ -81,7 +81,7 @@ jobs:
       - name: Environment
         run: |
           python utils/print_env.py
-      - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
+      - name: PyTorch CUDA checkpoint tests on Ubuntu
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
@@ -184,7 +184,7 @@ jobs:
         run: |
           python utils/print_env.py
 
-      - name: Run slow Flax TPU tests
+      - name: Run Flax TPU tests
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
@@ -232,7 +232,7 @@ jobs:
         run: |
           python utils/print_env.py
 
-      - name: Run slow ONNXRuntime CUDA tests
+      - name: Run ONNXRuntime CUDA tests
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 14 additions & 5 deletions
@@ -86,6 +86,15 @@ def save_model_card(
     validation_prompt=None,
     repo_folder=None,
 ):
+    if "large" in base_model:
+        model_variant = "SD3.5-Large"
+        license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+        variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+    else:
+        model_variant = "SD3"
+        license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+        variant_tags = ["sd3", "sd3-diffusers"]
+
     widget_dict = []
     if images is not None:
         for i, image in enumerate(images):
@@ -95,7 +104,7 @@ def save_model_card(
         )
 
     model_description = f"""
-# SD3 DreamBooth LoRA - {repo_id}
+# {model_variant} DreamBooth LoRA - {repo_id}
 
 <Gallery />
 
@@ -120,7 +129,7 @@ def save_model_card(
 ```py
 from diffusers import AutoPipelineForText2Image
 import torch
-pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
+pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
 pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
 image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
 ```
@@ -135,7 +144,7 @@ def save_model_card(
 
 ## License
 
-Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+Please adhere to the licensing terms as described [here]({license_url}).
 """
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
@@ -151,11 +160,11 @@ def save_model_card(
         "diffusers-training",
         "diffusers",
         "lora",
-        "sd3",
-        "sd3-diffusers",
         "template:sd-lora",
     ]
 
+    tags += variant_tags
+
     model_card = populate_model_card(model_card, tags=tags)
     model_card.save(os.path.join(repo_folder, "README.md"))
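
For context, the new branch keys only off the substring "large" in `base_model` (presumably the value of `--pretrained_model_name_or_path`). A minimal sketch of how it resolves for typical checkpoint ids; the helper name `resolve_variant` is illustrative and not part of the script:

```py
# Sketch of the variant selection added to save_model_card (helper name is hypothetical).
def resolve_variant(base_model: str):
    if "large" in base_model:
        return "SD3.5-Large", ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
    return "SD3", ["sd3", "sd3-diffusers"]

# SD3.5-Large repo ids contain "large"; SD3 Medium does not.
print(resolve_variant("stabilityai/stable-diffusion-3.5-large"))
# ('SD3.5-Large', ['sd3.5-large', 'sd3.5', 'sd3.5-diffusers'])
print(resolve_variant("stabilityai/stable-diffusion-3-medium-diffusers"))
# ('SD3', ['sd3', 'sd3-diffusers'])
```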

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 12 additions & 4 deletions
@@ -77,6 +77,15 @@ def save_model_card(
     validation_prompt=None,
     repo_folder=None,
 ):
+    if "large" in base_model:
+        model_variant = "SD3.5-Large"
+        license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+        variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+    else:
+        model_variant = "SD3"
+        license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+        variant_tags = ["sd3", "sd3-diffusers"]
+
     widget_dict = []
     if images is not None:
         for i, image in enumerate(images):
@@ -86,7 +95,7 @@ def save_model_card(
         )
 
     model_description = f"""
-# SD3 DreamBooth - {repo_id}
+# {model_variant} DreamBooth - {repo_id}
 
 <Gallery />
 
@@ -113,7 +122,7 @@ def save_model_card(
 
 ## License
 
-Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
+Please adhere to the licensing terms as described `[here]({license_url})`.
 """
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
@@ -128,10 +137,9 @@ def save_model_card(
         "text-to-image",
         "diffusers-training",
         "diffusers",
-        "sd3",
-        "sd3-diffusers",
         "template:sd-lora",
     ]
+    tags += variant_tags
 
     model_card = populate_model_card(model_card, tags=tags)
     model_card.save(os.path.join(repo_folder, "README.md"))

src/diffusers/callbacks.py

Lines changed: 56 additions & 3 deletions
@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
 
 class SDXLCFGCutoffCallback(PipelineCallback):
     """
-    Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
-    `cutoff_step_index`), this callback will disable the CFG.
+    Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
+    `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
 
     Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
     """
 
-    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+    tensor_inputs = [
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ]
 
     def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
         cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
             callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
             callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+
+        return callback_kwargs
+
+
+class SDXLControlnetCFGCutoffCallback(PipelineCallback):
+    """
+    Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
+    `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
+
+    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = [
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "image",
+    ]
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+            add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+            add_text_embeds = add_text_embeds[-1:]  # "-1" denotes the embeddings for conditional pooled text tokens
+
+            add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+            add_time_ids = add_time_ids[-1:]  # "-1" denotes the embeddings for conditional added time vector
+
+            # For Controlnet
+            image = callback_kwargs[self.tensor_inputs[3]]
+            image = image[-1:]
+
+            pipeline._guidance_scale = 0.0
+
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+            callback_kwargs[self.tensor_inputs[3]] = image
+
         return callback_kwargs
 
 
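Both callbacks resolve the cutoff step the same way: an explicit `cutoff_step_index` wins, otherwise the step is derived from `cutoff_step_ratio` and the pipeline's `num_timesteps`. A standalone sketch of that arithmetic, pulled out of the callback for clarity:

```py
# Standalone sketch of the cutoff-step resolution used by both callbacks.
def resolve_cutoff_step(num_timesteps, cutoff_step_ratio, cutoff_step_index=None):
    # An explicit index takes precedence over the ratio.
    return cutoff_step_index if cutoff_step_index is not None else int(num_timesteps * cutoff_step_ratio)

print(resolve_cutoff_step(num_timesteps=30, cutoff_step_ratio=0.4))                  # 12 -> CFG is switched off at step 12
print(resolve_cutoff_step(num_timesteps=30, cutoff_step_ratio=0.4, cutoff_step_index=5))  # 5 -> explicit index wins
```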

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 2 additions & 0 deletions
@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
         "add_time_ids",
         "negative_pooled_prompt_embeds",
         "negative_add_time_ids",
+        "image",
     ]
 
     def __init__(
@@ -1540,6 +1541,7 @@ def __call__(
                     )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                     negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+                    image = callback_outputs.pop("image", image)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
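Registering `image` as a callback tensor input is what lets the new `SDXLControlnetCFGCutoffCallback` trim the ControlNet conditioning image together with the text embeddings once CFG is switched off. A rough usage sketch, assuming the constructor follows `PipelineCallback` (`cutoff_step_ratio` / `cutoff_step_index`, as the diff's config access suggests) and using the public `stabilityai/stable-diffusion-xl-base-1.0` and `diffusers/controlnet-canny-sdxl-1.0` checkpoints; treat it as illustrative rather than a verified recipe:

```py
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.callbacks import SDXLControlnetCFGCutoffCallback

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Disable CFG after 40% of the denoising steps.
callback = SDXLControlnetCFGCutoffCallback(cutoff_step_ratio=0.4)

# Placeholder conditioning image; in practice this would be a real Canny edge map.
control_image = Image.new("RGB", (1024, 1024))

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    image=control_image,
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=callback.tensor_inputs,
).images[0]
```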

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 6 additions & 3 deletions
@@ -903,9 +903,12 @@ def __call__(
 
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 if isinstance(controlnet_keep[i], list):
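The rewritten block is needed because `FluxMultiControlNetModel` wraps its sub-controlnets in a `.nets` list, so reading `self.controlnet.config.guidance_embeds` directly only works for a single `FluxControlNetModel`; with a multi-controlnet the flag has to come from the first wrapped net. A condensed sketch of the same decision outside the pipeline (the helper name is made up):

```py
import torch
from diffusers import FluxMultiControlNetModel

def wants_guidance_embeds(controlnet) -> bool:
    """Illustrative helper mirroring the branch added to the pipeline."""
    if isinstance(controlnet, FluxMultiControlNetModel):
        # A multi-controlnet holds its sub-models in .nets; read the flag from the first one.
        return controlnet.nets[0].config.guidance_embeds
    return controlnet.config.guidance_embeds

# Inside the denoising loop the pipeline then builds the guidance tensor (or skips it):
# guidance = torch.tensor([guidance_scale], device=device) if wants_guidance_embeds(controlnet) else None
```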

tests/lora/test_lora_layers_flux.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     is_peft_available,
+    nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
@@ -165,9 +166,10 @@ def test_modify_padding_mode(self):
 
 
 @slow
+@nightly
 @require_torch_gpu
 @require_peft_backend
-# @unittest.skip("We cannot run inference on this model with the current CI hardware")
+@unittest.skip("We cannot run inference on this model with the current CI hardware")
 # TODO (DN6, sayakpaul): move these tests to a beefier GPU
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.

tests/lora/test_lora_layers_sd.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
     load_image,
+    nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
@@ -207,6 +208,7 @@ def test_integration_move_lora_dora_cpu(self):
 
 
 @slow
+@nightly
 @require_torch_gpu
 @require_peft_backend
 class LoraIntegrationTests(unittest.TestCase):

tests/lora/test_lora_layers_sdxl.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def tearDown(self):
 
 
 @slow
+@nightly
 @require_torch_gpu
 @require_peft_backend
 class LoraSDXLIntegrationTests(unittest.TestCase):
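
For reference, the `slow` and `nightly` markers from `diffusers.utils.testing_utils` are opt-in gates: a decorated test is skipped unless the corresponding environment variable (`RUN_SLOW` / `RUN_NIGHTLY`) is set, so stacking `@nightly` on top of `@slow` restricts these LoRA integration suites to runs where `RUN_NIGHTLY` is also enabled. A rough sketch of how such a marker is typically implemented; this is an approximation, not the exact diffusers source:

```py
import os
import unittest

def _env_flag_enabled(name: str) -> bool:
    # Treat "1", "true", "yes", "on" (any case) as enabled; unset or anything else as disabled.
    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}

def nightly(test_case):
    """Skip the test unless RUN_NIGHTLY is set (approximation of the diffusers decorator)."""
    return unittest.skipUnless(_env_flag_enabled("RUN_NIGHTLY"), "test requires RUN_NIGHTLY=1")(test_case)

@nightly
class ExampleNightlyTests(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)
```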
