
Commit fe6287a

Merge branch 'huggingface:main' into cogview4_control
2 parents: 19d7d27 + 9a8e8db

File tree: 4 files changed (+48, -5 lines)

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 2 additions & 1 deletion

@@ -207,7 +207,7 @@ class StableDiffusionControlNetPipeline(
     model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"]

     def __init__(
         self,
@@ -1323,6 +1323,7 @@ def __call__(
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    image = callback_outputs.pop("image", image)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
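With "image" added to `_callback_tensor_inputs`, a `callback_on_step_end` hook can now read or swap the preprocessed ControlNet conditioning tensor between denoising steps. A minimal sketch of how this could be used (the callback body and the cutoff step are illustrative, not part of this commit):

import torch

# Illustrative callback: zero out the ControlNet conditioning image for the
# final quarter of the denoising steps. `callback_kwargs` carries only the
# tensors requested via `callback_on_step_end_tensor_inputs`.
def cutoff_control(pipe, step_index, timestep, callback_kwargs):
    if step_index == int(pipe.num_timesteps * 0.75):
        callback_kwargs["image"] = torch.zeros_like(callback_kwargs["image"])
    return callback_kwargs

# usage sketch (model id and prompt are placeholders):
# pipe = StableDiffusionControlNetPipeline.from_pretrained(...)
# result = pipe(
#     prompt="...",
#     image=control_image,
#     callback_on_step_end=cutoff_control,
#     callback_on_step_end_tensor_inputs=["image"],
# )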

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 2 additions & 1 deletion

@@ -185,7 +185,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]

     def __init__(
         self,
@@ -1294,6 +1294,7 @@ def __call__(
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    control_image = callback_outputs.pop("control_image", control_image)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 2 additions & 1 deletion

@@ -184,7 +184,7 @@ class StableDiffusionControlNetInpaintPipeline(
     model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]

     def __init__(
         self,
@@ -1476,6 +1476,7 @@ def __call__(
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    control_image = callback_outputs.pop("control_image", control_image)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
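The img2img and inpaint pipelines expose the same hook, but under the key "control_image", matching their call-signature argument name. A sketch, reusing the illustrative cutoff from the txt2img example above:

def cutoff_control(pipe, step_index, timestep, callback_kwargs):
    # in the img2img/inpaint variants the conditioning tensor travels as
    # "control_image" rather than "image"
    if step_index == int(pipe.num_timesteps * 0.75):
        callback_kwargs["control_image"] = torch.zeros_like(callback_kwargs["control_image"])
    return callback_kwargs

# pass callback_on_step_end_tensor_inputs=["control_image"] when invoking the pipeline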

tests/lora/test_lora_layers_lumina2.py

Lines changed: 42 additions & 2 deletions

@@ -15,6 +15,8 @@
 import sys
 import unittest

+import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, GemmaForCausalLM

@@ -24,12 +26,12 @@
     Lumina2Text2ImgPipeline,
     Lumina2Transformer2DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device


 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


 @require_peft_backend
@@ -130,3 +132,41 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @skip_mps
+    @pytest.mark.xfail(
+        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+        strict=False,
+    )
+    def test_lora_fuse_nan(self):
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            # corrupt one LoRA weight with `inf` values
+            with torch.no_grad():
+                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+            # with `safe_fusing=True` we should see an error
+            with self.assertRaises(ValueError):
+                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+            # without it we should not see an error, but every image will be black
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+            out = pipe(**inputs)[0]
+
+            self.assertTrue(np.isnan(out).all())
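The new test pins down the `safe_fusing` contract of `fuse_lora`: with `safe_fusing=True`, corrupted adapter weights (NaN/inf) raise a `ValueError` before fusion; with `safe_fusing=False`, fusion proceeds and inference quietly yields NaN outputs. A minimal sketch of that contract (the module path and adapter name are illustrative):

import torch

# assume `pipe` is a diffusers pipeline with a PEFT LoRA adapter "adapter-1" loaded
with torch.no_grad():
    pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

try:
    pipe.fuse_lora(safe_fusing=True)  # validates fused weights, raises on NaN/inf
except ValueError:
    print("corrupted adapter caught before inference")

pipe.fuse_lora(safe_fusing=False)  # fuses without checks; outputs will be NaN ("black" images)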
