Skip to content

Commit 1040dfd

Browse files
[Fix] Multiple image conditionings in a single batch for StableDiffusionControlNetPipeline (#6334)
* [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline`. * Refactor `check_inputs` in `StableDiffusionControlNetPipeline` to avoid redundant codes. * Make the behavior of MultiControlNetModel to be the same to the original ControlNetModel * Keep the code change minimum for nested list support * Add fast test `test_inference_nested_image_input` * Remove redundant check for nested image condition in `check_inputs` Remove `len(image) == len(prompt)` check out of `check_image()` Co-authored-by: YiYi Xu <yixu310@gmail.com> * Better `ValueError` message for incompatible nested image list size Co-authored-by: YiYi Xu <yixu310@gmail.com> * Fix syntax error in `check_inputs` * Remove warning message for multi-ControlNets with multiple prompts * Fix a typo in test_controlnet.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Add test case for multiple prompts, single image conditioning in `StableDiffusionMultiControlNetPipelineFastTests` * Improved `ValueError` message for nested `controlnet_conditioning_scale` * Documenting the behavior of image list as `StableDiffusionControlNetPipeline` input --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 49a4b37 commit 1040dfd

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -603,15 +603,6 @@ def check_inputs(
603603
f" {negative_prompt_embeds.shape}."
604604
)
605605

606-
# `prompt` needs more sophisticated handling when there are multiple
607-
# conditionings.
608-
if isinstance(self.controlnet, MultiControlNetModel):
609-
if isinstance(prompt, list):
610-
logger.warning(
611-
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
612-
" prompts. The conditionings will be fixed across the prompts."
613-
)
614-
615606
# Check `image`
616607
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
617608
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -633,7 +624,13 @@ def check_inputs(
633624
# When `image` is a nested list:
634625
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
635626
elif any(isinstance(i, list) for i in image):
636-
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
627+
transposed_image = [list(t) for t in zip(*image)]
628+
if len(transposed_image) != len(self.controlnet.nets):
629+
raise ValueError(
630+
f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
631+
)
632+
for image_ in transposed_image:
633+
self.check_image(image_, prompt, prompt_embeds)
637634
elif len(image) != len(self.controlnet.nets):
638635
raise ValueError(
639636
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
@@ -659,7 +656,10 @@ def check_inputs(
659656
):
660657
if isinstance(controlnet_conditioning_scale, list):
661658
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
662-
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
659+
raise ValueError(
660+
"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
661+
"The conditioning scale must be fixed across the batch."
662+
)
663663
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
664664
self.controlnet.nets
665665
):
@@ -906,7 +906,9 @@ def __call__(
906906
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
907907
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
908908
`init`, images must be passed as a list such that each element of the list can be correctly batched for
909-
input to a single ControlNet.
909+
input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
910+
each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
911+
where a list of image lists can be passed to batch for each prompt and each ControlNet.
910912
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
911913
The height in pixels of the generated image.
912914
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1105,6 +1107,11 @@ def __call__(
11051107
elif isinstance(controlnet, MultiControlNetModel):
11061108
images = []
11071109

1110+
# Nested lists as ControlNet condition
1111+
if isinstance(image[0], list):
1112+
# Transpose the nested image list
1113+
image = [list(t) for t in zip(*image)]
1114+
11081115
for image_ in image:
11091116
image_ = self.prepare_image(
11101117
image=image_,

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,33 @@ def test_save_pretrained_raise_not_implemented_exception(self):
460460
except NotImplementedError:
461461
pass
462462

463+
def test_inference_multiple_prompt_input(self):
464+
device = "cpu"
465+
466+
components = self.get_dummy_components()
467+
sd_pipe = StableDiffusionControlNetPipeline(**components)
468+
sd_pipe = sd_pipe.to(torch_device)
469+
sd_pipe.set_progress_bar_config(disable=None)
470+
471+
inputs = self.get_dummy_inputs(device)
472+
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
473+
inputs["image"] = [inputs["image"], inputs["image"]]
474+
output = sd_pipe(**inputs)
475+
image = output.images
476+
477+
assert image.shape == (2, 64, 64, 3)
478+
479+
image_1, image_2 = image
480+
# make sure that the outputs are different
481+
assert np.sum(np.abs(image_1 - image_2)) > 1e-3
482+
483+
# multiple prompts, single image conditioning
484+
inputs = self.get_dummy_inputs(device)
485+
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
486+
output_1 = sd_pipe(**inputs)
487+
488+
assert np.abs(image - output_1.images).max() < 1e-3
489+
463490

464491
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
465492
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase

0 commit comments

Comments
 (0)