
[SDXL] Partial diffusion support for Text2Img and Img2Img Pipelines #4015


**Merged** — 20 commits — Jul 12, 2023

Changes from all commits:

110 changes: 103 additions & 7 deletions docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx
@@ -59,21 +59,117 @@ image = pipe(prompt=prompt).images[0]

### Refining the image output

In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
StableDiffusion-XL also includes a [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
that is specialized in denoising low-noise stage images to generate images of improved high-frequency quality.
This refiner checkpoint can be used as a "second-step" pipeline after running the base checkpoint to improve
image quality.

When using the refiner, one can easily
- 1.) employ the base model and refiner as an *Ensemble of Expert Denoisers* as first proposed in [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/) or
- 2.) simply run the refiner in [SDEdit](https://arxiv.org/abs/2108.01073) fashion after the base model.

**Note**: The idea of using SD-XL base & refiner as an ensemble of experts was first brought forward by
a couple of community contributors, who also helped shape the following `diffusers` implementation, namely:
- [SytanSD](https://github.com/SytanSD)
- [bghira](https://github.com/bghira)
- [Birch-san](https://github.com/Birch-san)

#### 1.) Ensemble of Expert Denoisers

When using the base and refiner model as an ensemble of expert denoisers, the base model serves as the
expert for the high-noise diffusion stage and the refiner serves as the expert for the low-noise diffusion stage.

The advantage of 1.) over 2.) is that it requires fewer overall denoising steps and therefore should be significantly
faster. The drawback is that one cannot really inspect the output of the base model; it will still be heavily noised.

To use the base model and refiner as an ensemble of expert denoisers, make sure to define the fraction
of timesteps which should be run through the high-noise denoising stage (*i.e.* the base model) and the low-noise
denoising stage (*i.e.* the refiner model) respectively. This fraction should be set as the [`~StableDiffusionXLPipeline.__call__.denoising_end`] of the base model
and as the [`~StableDiffusionXLImg2ImgPipeline.__call__.denoising_start`] of the refiner model.

Let's look at an example.
First, we import the two pipelines. Since the text encoders and variational autoencoder are the same,
you don't have to load them again for the refiner.

```py
from diffusers import DiffusionPipeline
import torch

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
base.to("cuda")

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
```

Now we define the number of inference steps and the fraction of the process that should be run through the
high-noise denoising stage (*i.e.* the base model).

```py
n_steps = 40
high_noise_frac = 0.7
```

A fraction of 0.7 means that 70% of the 40 inference steps (28 steps) are run through the base model
and the remaining 12 steps are run through the refiner. Let's run the two pipelines now.
Make sure to set `denoising_end` and `denoising_start` to the same value and keep `num_inference_steps`
constant. Also remember that the output of the base model should be in latent space:

```py
prompt = "A majestic lion jumping from a big stone at night"

image = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=high_noise_frac, output_type="latent").images
image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=high_noise_frac, image=image).images[0]
```

Let's have a look at the image:

![lion_ref](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/lion_refined.png)

If we had just run the base model on the same 40 steps, the image would arguably have been less detailed (e.g. the lion's eyes and nose):

![lion_base](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/lion_base.png)

<Tip>

The ensemble-of-experts method works well on all available schedulers!

</Tip>
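
For instance, here is a minimal sketch of swapping in a different scheduler before running the ensemble (the choice of `EulerDiscreteScheduler` is illustrative; the same `from_config` pattern applies to any scheduler shipped with `diffusers`):

```py
from diffusers import EulerDiscreteScheduler

# Both experts should walk the same noise schedule, so build the replacement
# scheduler for each pipeline from that pipeline's existing scheduler config.
base.scheduler = EulerDiscreteScheduler.from_config(base.scheduler.config)
refiner.scheduler = EulerDiscreteScheduler.from_config(refiner.scheduler.config)
```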

#### 2.) Refining the image output from a fully denoised base image

In standard [`StableDiffusionImg2ImgPipeline`]-fashion, the fully denoised image generated by the base model
can be further improved using the [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).

**Member:** `StableDiffusionXLImg2ImgPipeline`? and not `StableDiffusionImg2ImgPipeline`?

**Contributor (author):** suggestion: list advantages and disadvantages of this method

- disadvantages: wasteful with compute cycles, destroys detail in the incoming image
- advantages: can handle externally sourced or previously-generated images to provide 'variations', when the latent space is no longer available

**Member:** I'd be in favor of including these!

If we can, let's maybe also highlight the point on reducing inference latency?

**Contributor:** Want to wait here a bit to see what SAI will use as the "official" way.

For this, you simply run the refiner as a normal image-to-image pipeline after the "base" text-to-image
pipeline. You can leave the outputs of the base model in latent space.

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
```
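
A sketch of how the two pipelines might then be invoked in this mode (the prompt and step count are illustrative assumptions rather than part of the diff):

```py
prompt = "A majestic lion jumping from a big stone at night"

# Fully denoise with the base model, keeping the output in latent space,
# then run the refiner as a plain image-to-image pass on top.
image = pipe(prompt=prompt, num_inference_steps=40, output_type="latent").images
image = refiner(prompt=prompt, image=image).images[0]
```
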
@@ -545,6 +545,7 @@ def __call__(

```py
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
```

@@ -579,6 +580,14 @@

```py
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. For instance, if `denoising_end` is set to
                0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
                denoising steps. As a result, the returned sample will still retain a substantial amount of noise. The
                `denoising_end` parameter should ideally be utilized when this pipeline forms a part of a "Mixture of
                Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
```

**Contributor (author):** as elaborated in []

@@ -746,7 +755,13 @@ def __call__(

```py
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None:
            num_inference_steps = int(round(denoising_end * num_inference_steps))
            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
```

**Contributor (author):** (on the `max(..., 0)` clamp) i noticed that going negative, as well.

**Contributor (author):** (on applying `denoising_end`) this resolved the timestamp bar not completing. however, i really liked that emergent property, as it meant any progress bar capturing the tqdm output would "just magically" have the progress bar continue on logically from one to the next. 😥 more work for that result, but i've been needing to refactor how i handle that, anyway...
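
To make the truncation arithmetic concrete, a standalone sketch with the numbers used in the docs above (assuming a first-order scheduler, i.e. `scheduler.order == 1`, and zero warmup steps):

```py
num_inference_steps = 40
denoising_end = 0.7

# The base model executes only the first 70% of the schedule:
num_inference_steps = int(round(denoising_end * num_inference_steps))  # 28
# `timesteps` is then sliced down to its first 28 entries, and the
# remaining 12 steps are left for the refiner.
```
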
@@ -456,11 +456,24 @@ def prepare_extra_step_kwargs(self, generator, eta):

```py
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        strength,
        num_inference_steps,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if num_inference_steps is None:
            raise ValueError("`num_inference_steps` cannot be None.")
        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
            raise ValueError(
                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
                f" {type(num_inference_steps)}."
            )
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
```

@@ -495,16 +508,21 @@ def check_inputs(

```py
                    f" {negative_prompt_embeds.shape}."
                )

    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
        # get the original timestep using init_timestep
        if denoising_start is None:
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
            t_start = max(num_inference_steps - init_timestep, 0)
        else:
            t_start = int(round(denoising_start * num_inference_steps))

        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(
        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )
```

@@ -557,11 +575,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):

```py
        else:
            init_latents = torch.cat([init_latents], dim=0)

        if add_noise:
            shape = init_latents.shape
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # get latents
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents
```
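
As a quick sanity check of the two `get_timesteps` branches above (standalone arithmetic, again assuming `scheduler.order == 1`):

```py
num_inference_steps = 50

# strength path (denoising_start is None): strength=0.3 keeps the last 15 steps
init_timestep = min(int(num_inference_steps * 0.3), num_inference_steps)  # 15
t_start = max(num_inference_steps - init_timestep, 0)                     # 35

# denoising_start path: denoising_start=0.7 starts at the same point
t_start = int(round(0.7 * num_inference_steps))                           # 35

# either way, the pipeline runs num_inference_steps - t_start = 15 steps
```
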
@@ -620,6 +639,8 @@ def __call__(

```py
        ] = None,
        strength: float = 0.3,
        num_inference_steps: int = 50,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
```

@@ -651,7 +672,7 @@

```py
                instead.
            image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                The image(s) to modify with the pipeline.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
```

@@ -660,6 +681,24 @@

```py
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_start (`float`, *optional*):
                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
                bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and
                `num_inference_steps` is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50)
                denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed
                that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly
                beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as
                detailed in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. For instance, if `denoising_end` is set to
                0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
                denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca.
                30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it
                only denoises the final 30%. The `denoising_end` parameter should ideally be utilized when this
                pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining
                the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
```

@@ -738,7 +777,15 @@

```py
                "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
```

@@ -781,13 +828,25 @@

```py
        image = self.image_processor.preprocess(image)

        # 5. Prepare timesteps
        original_num_steps = num_inference_steps  # save for denoising_start/end later

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps, strength, device, denoising_start=denoising_start
        )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        add_noise = True if denoising_start is None else False
        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise,
        )
        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
```
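
The `add_noise` toggle above is what makes partial diffusion work: when `denoising_start` is set, the incoming `image` is treated as an already partly denoised latent from a predecessor pipeline and must not be re-noised. A hedged sketch of the two entry modes (`pil_image` and `latents` are assumed inputs, not names from the diff):

```py
# Regular img2img: a clean input image is noised according to `strength` first.
out = refiner(prompt=prompt, image=pil_image, strength=0.3).images[0]

# Partial diffusion: a partly denoised latent enters as-is
# (add_noise=False internally, because denoising_start is set).
out = refiner(prompt=prompt, image=latents, denoising_start=0.7).images[0]
```
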
@@ -820,7 +879,22 @@

```py
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 9. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 9.1 Apply denoising_end
        if denoising_end is not None and denoising_start is not None:
            if denoising_start >= denoising_end:
                raise ValueError(
                    f"`denoising_end`: {denoising_end} cannot be smaller than or equal to `denoising_start`: {denoising_start}."
                )

            skipped_final_steps = int(round((1 - denoising_end) * original_num_steps))
            num_inference_steps = num_inference_steps - skipped_final_steps
            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
        elif denoising_end is not None:
            num_inference_steps = int(round(denoising_end * num_inference_steps))
            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
```
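
Putting `denoising_start` and `denoising_end` together, a sketch of a three-stage "Mixture of Denoisers" chain that exercises the new branch above (the 0.6/0.8 split points are illustrative; `base` and `refiner` are the pipelines set up in the docs diff):

```py
n_steps = 40
prompt = "A majestic lion jumping from a big stone at night"

# Stage 1: the base expert covers the high-noise region and stops at 60%.
latents = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=0.6, output_type="latent").images

# Stage 2: the refiner picks up at 60% and hands off at 80%
# (denoising_start < denoising_end, so the new guard passes).
latents = refiner(
    prompt=prompt, num_inference_steps=n_steps, denoising_start=0.6, denoising_end=0.8,
    image=latents, output_type="latent",
).images

# Stage 3: the refiner finishes the remaining 20%.
image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=0.8, image=latents).images[0]
```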