Commit 99b540b

bghira, patrickvonplaten, pcuenca, and sayakpaul
authored
[SDXL] Partial diffusion support for Text2Img and Img2Img Pipelines (#4015)
* diffusers#4003 - initial implementation of max_inference_steps
* diffusers#4003 - initial implementation of max_inference_steps and first_inference_step for img2img
* diffusers#4003 - use first_inference_step as an input arg for get_timestamps in img2img
* diffusers#4003 Do not add noise during img2img when we have a defined first timestep
* diffusers#4003 Mild updates after revert
* diffusers#4003 Missing change
* Show implementation with denoising_start and end
* Apply suggestions from code review
* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* move to 0.19.0dev
* Apply suggestions from code review
* add exhaustive tests
* add docs
* finish
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* make style

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent b9feed8 commit 99b540b

File tree

4 files changed

+341
-23
lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx

Lines changed: 103 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -59,21 +59,117 @@ image = pipe(prompt=prompt).images[0]
5959

6060
### Refining the image output
6161

62-
The image can be refined by making use of [stabilityai/stable-diffusion-xl-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
63-
In this case, you only have to output the `latents` from the base model.
62+
In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
63+
StableDiffusion-XL also includes a [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
64+
that specializes in denoising low-noise-stage images to produce images with improved high-frequency detail.
65+
This refiner checkpoint can be used as a "second-step" pipeline after having run the base checkpoint to improve
66+
image quality.
67+
68+
When using the refiner, one can easily
69+
- 1.) employ the base model and refiner as an *Ensemble of Expert Denoisers* as first proposed in [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/) or
70+
- 2.) simply run the refiner in [SDEdit](https://arxiv.org/abs/2108.01073) fashion after the base model.
71+
72+
**Note**: The idea of using SD-XL base & refiner as an ensemble of experts was first brought forward by
73+
a couple of community contributors who also helped shape the following `diffusers` implementation, namely:
74+
- [SytanSD](https://github.com/SytanSD)
75+
- [bghira](https://github.com/bghira)
76+
- [Birch-san](https://github.com/Birch-san)
77+
78+
#### 1.) Ensemble of Expert Denoisers
79+
80+
When using the base and refiner models as an ensemble of expert denoisers, the base model serves as the
81+
expert for the high-noise diffusion stage and the refiner serves as the expert for the low-noise diffusion stage.
82+
83+
The advantage of 1.) over 2.) is that it requires fewer overall denoising steps and should therefore be significantly
84+
faster. The drawback is that one cannot really inspect the output of the base model; it will still contain a substantial amount of noise.
85+
86+
To use the base model and refiner as an ensemble of expert denoisers, make sure to define the fraction
87+
of timesteps which should be run through the high-noise denoising stage (*i.e.* the base model) and the low-noise
88+
denoising stage (*i.e.* the refiner model) respectively. This fraction should be set as the [`~StableDiffusionXLPipeline.__call__.denoising_end`] of the base model
89+
and as the [`~StableDiffusionXLImg2ImgPipeline.__call__.denoising_start`] of the refiner model.
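To make the fraction concrete, here is a minimal sketch of how a single fraction could translate into step counts for the two stages. The helper name `split_steps` is hypothetical and not part of `diffusers`; the rounding mirrors the `int(round(denoising_end * num_inference_steps))` expression used in this commit.

```python
from typing import Tuple


def split_steps(num_inference_steps: int, high_noise_frac: float) -> Tuple[int, int]:
    """Return (base_steps, refiner_steps) for a given high-noise fraction.

    Hypothetical helper for illustration only; it is not a diffusers API.
    """
    if not 0.0 <= high_noise_frac <= 1.0:
        raise ValueError(f"high_noise_frac must be in [0.0, 1.0] but is {high_noise_frac}")
    # Same rounding as the pipeline uses when applying `denoising_end`.
    base_steps = int(round(high_noise_frac * num_inference_steps))
    # Whatever the base model does not run is left for the refiner.
    return base_steps, num_inference_steps - base_steps


base_steps, refiner_steps = split_steps(40, 0.7)
print(base_steps, refiner_steps)
```

Because both stages derive their step counts from the same fraction and the same `num_inference_steps`, the two counts always add up to the full schedule.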
90+
91+
Let's look at an example.
92+
First, we load the two pipelines. Since the text encoders and variational autoencoder are the same,
93+
you don't have to load those again for the refiner.
6494

6595
```py
66-
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
96+
from diffusers import DiffusionPipeline
6797
import torch
6898

69-
pipe = StableDiffusionXLPipeline.from_pretrained(
99+
base = DiffusionPipeline.from_pretrained(
70100
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
71101
)
72102
base.to("cuda")
73103

74-
use_refiner = True
75-
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
76-
"stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
104+
refiner = DiffusionPipeline.from_pretrained(
105+
"stabilityai/stable-diffusion-xl-refiner-0.9",
106+
text_encoder_2=base.text_encoder_2,
107+
vae=base.vae,
108+
torch_dtype=torch.float16,
109+
use_safetensors=True,
110+
variant="fp16",
111+
)
112+
refiner.to("cuda")
113+
```
114+
115+
Now we define the number of inference steps and the fraction of those steps that should be run through the
116+
high-noise denoising stage (*i.e.* the base model).
117+
118+
```py
119+
n_steps = 40
120+
high_noise_frac = 0.7
121+
```
122+
123+
A fraction of 0.7 means that 70% of the 40 inference steps (28 steps) are run through the base model
124+
and the remaining 12 steps are run through the refiner. Let's run the two pipelines now.
125+
Make sure to set `denoising_end` and `denoising_start` to the same values and keep `num_inference_steps`
126+
constant. Also remember that the output of the base model should be in latent space:
127+
128+
```py
129+
prompt = "A majestic lion jumping from a big stone at night"
130+
131+
image = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=high_noise_frac, output_type="latent").images
132+
image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=high_noise_frac, image=image).images[0]
133+
```
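The hand-off between the two calls above can be pictured as slicing one shared timestep schedule in two. The sketch below is an illustration under stated assumptions, not the pipelines' actual code: it uses a made-up, evenly spaced schedule instead of a real scheduler, assumes a first-order scheduler with no warmup steps, and the function name `split_timesteps` is hypothetical.

```python
from typing import List, Tuple


def split_timesteps(timesteps: List[int], frac: float, order: int = 1) -> Tuple[List[int], List[int]]:
    """Split a descending timestep schedule at `frac`.

    The first part corresponds to `denoising_end=frac` (base model),
    the second to `denoising_start=frac` (refiner). Hypothetical helper.
    """
    num_inference_steps = len(timesteps) // order
    cut = int(round(frac * num_inference_steps)) * order
    return timesteps[:cut], timesteps[cut:]


# Assumed schedule: 40 evenly spaced timesteps from 975 down to 0.
timesteps = list(range(975, -1, -25))

base_ts, refiner_ts = split_timesteps(timesteps, 0.7)
print(len(base_ts), len(refiner_ts))
print(base_ts[-1], refiner_ts[0])
```

The refiner's first timestep directly follows the base model's last one, which is why both calls need the same fraction and the same `num_inference_steps`; otherwise the two slices would overlap or leave a gap.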
134+
135+
Let's have a look at the image:
136+
137+
![lion_ref](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png)
138+
139+
If we had just run the base model for the same 40 steps, the image would arguably have been less detailed (e.g. the lion's eyes and nose):
140+
141+
![lion_base](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png)
142+
143+
<Tip>
144+
145+
The ensemble-of-experts method works well on all available schedulers!
146+
147+
</Tip>
148+
149+
#### Refining the image output from fully denoised base image
150+
151+
In standard [`StableDiffusionImg2ImgPipeline`]-fashion, the fully-denoised image generated by the base model
152+
can be further improved using the [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
153+
154+
For this, you simply run the refiner as a normal image-to-image pipeline after the "base" text-to-image
155+
pipeline. You can leave the outputs of the base model in latent space.
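In this mode, how many steps the refiner actually runs is governed by `strength` rather than by `denoising_start`. As a rough sketch, the function below mirrors the arithmetic of `get_timesteps` in the img2img pipeline for the case where `denoising_start` is not set; the standalone function name `img2img_schedule` is hypothetical.

```python
from typing import Tuple


def img2img_schedule(num_inference_steps: int, strength: float) -> Tuple[int, int]:
    """Return (t_start, steps_run) for a plain image-to-image call.

    Mirrors the `strength` branch of `get_timesteps`; illustration only.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
    # Number of steps' worth of noise that is added to the input image.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Index into the schedule at which denoising starts.
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start, num_inference_steps - t_start


# With the pipeline defaults (50 steps, strength 0.3) only the tail of the schedule is run.
print(img2img_schedule(50, 0.3))
```

A low `strength` keeps the refiner close to the base model's image, while `strength=1.0` would re-noise it completely and run the full schedule.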
156+
157+
```py
158+
from diffusers import DiffusionPipeline
159+
import torch
160+
161+
pipe = DiffusionPipeline.from_pretrained(
162+
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
163+
)
164+
pipe.to("cuda")
165+
166+
refiner = DiffusionPipeline.from_pretrained(
167+
"stabilityai/stable-diffusion-xl-refiner-0.9",
168+
text_encoder_2=pipe.text_encoder_2,
169+
vae=pipe.vae,
170+
torch_dtype=torch.float16,
171+
use_safetensors=True,
172+
variant="fp16",
77173
)
78174
refiner.to("cuda")
79175

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 16 additions & 1 deletion
@@ -545,6 +545,7 @@ def __call__(
545545
height: Optional[int] = None,
546546
width: Optional[int] = None,
547547
num_inference_steps: int = 50,
548+
denoising_end: Optional[float] = None,
548549
guidance_scale: float = 5.0,
549550
negative_prompt: Optional[Union[str, List[str]]] = None,
550551
num_images_per_prompt: Optional[int] = 1,
@@ -579,6 +580,14 @@ def __call__(
579580
num_inference_steps (`int`, *optional*, defaults to 50):
580581
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
581582
expense of slower inference.
583+
denoising_end (`float`, *optional*):
584+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
585+
completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
586+
0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
587+
denoising steps. As a result, the returned sample will still retain a substantial amount of noise. The
588+
denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of
589+
Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
590+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
582591
guidance_scale (`float`, *optional*, defaults to 5.0):
583592
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
584593
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -746,7 +755,13 @@ def __call__(
746755
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
747756

748757
# 8. Denoising loop
749-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
758+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
759+
760+
# 8.1 Apply denoising_end
761+
if denoising_end is not None:
762+
num_inference_steps = int(round(denoising_end * num_inference_steps))
763+
timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
764+
750765
with self.progress_bar(total=num_inference_steps) as progress_bar:
751766
for i, t in enumerate(timesteps):
752767
# expand the latents if we are doing classifier free guidance

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 89 additions & 15 deletions
@@ -456,11 +456,24 @@ def prepare_extra_step_kwargs(self, generator, eta):
456456
return extra_step_kwargs
457457

458458
def check_inputs(
459-
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
459+
self,
460+
prompt,
461+
strength,
462+
num_inference_steps,
463+
callback_steps,
464+
negative_prompt=None,
465+
prompt_embeds=None,
466+
negative_prompt_embeds=None,
460467
):
461468
if strength < 0 or strength > 1:
462469
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
463-
470+
if num_inference_steps is None:
471+
raise ValueError("`num_inference_steps` cannot be None.")
472+
elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
473+
raise ValueError(
474+
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
475+
f" {type(num_inference_steps)}."
476+
)
464477
if (callback_steps is None) or (
465478
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
466479
):
@@ -495,16 +508,21 @@ def check_inputs(
495508
f" {negative_prompt_embeds.shape}."
496509
)
497510

498-
def get_timesteps(self, num_inference_steps, strength, device):
511+
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
499512
# get the original timestep using init_timestep
500-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
513+
if denoising_start is None:
514+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
515+
t_start = max(num_inference_steps - init_timestep, 0)
516+
else:
517+
t_start = int(round(denoising_start * num_inference_steps))
501518

502-
t_start = max(num_inference_steps - init_timestep, 0)
503519
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
504520

505521
return timesteps, num_inference_steps - t_start
506522

507-
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
523+
def prepare_latents(
524+
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
525+
):
508526
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
509527
raise ValueError(
510528
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
@@ -557,11 +575,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
557575
else:
558576
init_latents = torch.cat([init_latents], dim=0)
559577

560-
shape = init_latents.shape
561-
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
578+
if add_noise:
579+
shape = init_latents.shape
580+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
581+
# get latents
582+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
562583

563-
# get latents
564-
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
565584
latents = init_latents
566585

567586
return latents
@@ -620,6 +639,8 @@ def __call__(
620639
] = None,
621640
strength: float = 0.3,
622641
num_inference_steps: int = 50,
642+
denoising_start: Optional[float] = None,
643+
denoising_end: Optional[float] = None,
623644
guidance_scale: float = 5.0,
624645
negative_prompt: Optional[Union[str, List[str]]] = None,
625646
num_images_per_prompt: Optional[int] = 1,
@@ -651,7 +672,7 @@ def __call__(
651672
instead.
652673
image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
653674
The image(s) to modify with the pipeline.
654-
strength (`float`, *optional*, defaults to 0.8):
675+
strength (`float`, *optional*, defaults to 0.3):
655676
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
656677
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
657678
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
@@ -660,6 +681,24 @@ def __call__(
660681
num_inference_steps (`int`, *optional*, defaults to 50):
661682
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
662683
expense of slower inference.
684+
denoising_start (`float`, *optional*):
685+
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
686+
bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and
687+
num_inference_steps is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50)
688+
denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed
689+
that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly
690+
beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as
691+
detailed in [**Refining the Image
692+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
693+
denoising_end (`float`, *optional*):
694+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
695+
completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
696+
0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
697+
denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca.
698+
30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it
699+
only denoises the final 30%. The `denoising_end` parameter should ideally be utilized when this pipeline
700+
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
701+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
663702
guidance_scale (`float`, *optional*, defaults to 5.0):
664703
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
665704
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -738,7 +777,15 @@ def __call__(
738777
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
739778
"""
740779
# 1. Check inputs. Raise error if not correct
741-
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
780+
self.check_inputs(
781+
prompt,
782+
strength,
783+
num_inference_steps,
784+
callback_steps,
785+
negative_prompt,
786+
prompt_embeds,
787+
negative_prompt_embeds,
788+
)
742789

743790
# 2. Define call parameters
744791
if prompt is not None and isinstance(prompt, str):
@@ -781,13 +828,25 @@ def __call__(
781828
image = self.image_processor.preprocess(image)
782829

783830
# 5. Prepare timesteps
831+
original_num_steps = num_inference_steps # save for denoising_start/end later
832+
784833
self.scheduler.set_timesteps(num_inference_steps, device=device)
785-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
834+
timesteps, num_inference_steps = self.get_timesteps(
835+
num_inference_steps, strength, device, denoising_start=denoising_start
836+
)
786837
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
787838

839+
add_noise = denoising_start is None
788840
# 6. Prepare latent variables
789841
latents = self.prepare_latents(
790-
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
842+
image,
843+
latent_timestep,
844+
batch_size,
845+
num_images_per_prompt,
846+
prompt_embeds.dtype,
847+
device,
848+
generator,
849+
add_noise,
791850
)
792851
# 7. Prepare extra step kwargs.
793852
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -820,7 +879,22 @@ def __call__(
820879
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
821880

822881
# 9. Denoising loop
823-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
882+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
883+
884+
# 9.1 Apply denoising_end
885+
if denoising_end is not None and denoising_start is not None:
886+
if denoising_start >= denoising_end:
887+
raise ValueError(
888+
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: {denoising_end}."
889+
)
890+
891+
skipped_final_steps = int(round((1 - denoising_end) * original_num_steps))
892+
num_inference_steps = num_inference_steps - skipped_final_steps
893+
timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
894+
elif denoising_end is not None:
895+
num_inference_steps = int(round(denoising_end * num_inference_steps))
896+
timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
897+
824898
with self.progress_bar(total=num_inference_steps) as progress_bar:
825899
for i, t in enumerate(timesteps):
826900
# expand the latents if we are doing classifier free guidance
