diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx
index 6df873edab00..1c5ea390a49f 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx
@@ -59,21 +59,117 @@ image = pipe(prompt=prompt).images[0]
 ### Refining the image output
 
-The image can be refined by making use of [stabilityai/stable-diffusion-xl-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
-In this case, you only have to output the `latents` from the base model.
+In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
+StableDiffusion-XL also includes a [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
+that specializes in denoising images in the low-noise stages to improve their high-frequency quality.
+The refiner can therefore be used as a "second-step" pipeline after the base checkpoint to improve
+image quality.
+
+When using the refiner, one can either
+- 1.) employ the base model and refiner as an *Ensemble of Expert Denoisers* as first proposed in [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/), or
+- 2.) simply run the refiner in [SDEdit](https://arxiv.org/abs/2108.01073) fashion after the base model.
+
+**Note**: The idea of using SD-XL base & refiner as an ensemble of experts was first brought forward by
+a couple of community contributors who also helped shape the following `diffusers` implementation, namely:
+- [SytanSD](https://github.com/SytanSD)
+- [bghira](https://github.com/bghira)
+- [Birch-san](https://github.com/Birch-san)
+
+#### 1.) Ensemble of Expert Denoisers
+
+When using the base and refiner model as an ensemble of expert denoisers, the base model should serve as the
+expert for the high-noise diffusion stage and the refiner as the expert for the low-noise diffusion stage.
+
+The advantage of 1.) over 2.) is that it requires fewer denoising steps overall and should therefore be significantly
+faster. The drawback is that one cannot really inspect the output of the base model; it will still be heavily noised.
+
+To use the base model and refiner as an ensemble of expert denoisers, define the fraction
+of timesteps which should be run through the high-noise denoising stage (*i.e.* the base model) and the low-noise
+denoising stage (*i.e.* the refiner model) respectively. This fraction should be set as [`~StableDiffusionXLPipeline.__call__.denoising_end`] for the base model
+and as [`~StableDiffusionXLImg2ImgPipeline.__call__.denoising_start`] for the refiner model.
+
+Let's look at an example.
+First, we import the two pipelines. Since the refiner shares its second text encoder and its variational autoencoder
+with the base model, you don't have to load those components again for the refiner.
 ```py
-from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+from diffusers import DiffusionPipeline
 import torch
 
-pipe = StableDiffusionXLPipeline.from_pretrained(
+base = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
 )
-pipe.to("cuda")
+base.to("cuda")
 
-use_refiner = True
-refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+refiner = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-0.9",
+    text_encoder_2=base.text_encoder_2,
+    vae=base.vae,
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16",
+)
+refiner.to("cuda")
+```
+
+Now we define the number of inference steps and the fraction of those steps that should be run through the
+high-noise denoising stage (*i.e.* the base model):
+
+```py
+n_steps = 40
+high_noise_frac = 0.7
+```
+
+A fraction of 0.7 means that 70% of the 40 inference steps (28 steps) are run through the base model
+and the remaining 12 steps are run through the refiner. Let's run the two pipelines now.
+Make sure to set `denoising_end` and `denoising_start` to the same value and keep `num_inference_steps`
+constant. Also remember that the output of the base model should be in latent space:
+
+```py
+prompt = "A majestic lion jumping from a big stone at night"
+
+image = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=high_noise_frac, output_type="latent").images
+image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=high_noise_frac, image=image).images[0]
+```
+
+Let's have a look at the image:
+
+![lion_ref](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png)
+
+If we had just run the base model for the same 40 steps, the image would arguably have been less detailed (e.g. the lion's eyes and nose):
+
+![lion_base](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png)
+
+<Tip>
+
+The ensemble-of-experts method works well on all available schedulers!
+
+</Tip>
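+
+For instance, one could swap in a different scheduler for both experts. The following is a minimal, illustrative
+sketch (not part of the original example) that reuses `base`, `refiner`, `prompt`, `n_steps`, and `high_noise_frac`
+from above:
+
+```py
+from diffusers import DPMSolverMultistepScheduler
+
+# both experts must walk the same timestep schedule, so swap the scheduler on both pipelines
+base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
+refiner.scheduler = DPMSolverMultistepScheduler.from_config(refiner.scheduler.config)
+
+image = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=high_noise_frac, output_type="latent").images
+image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=high_noise_frac, image=image).images[0]
+```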
+
+#### 2.) Refining the image output from a fully denoised base image
+
+In standard [`StableDiffusionImg2ImgPipeline`] fashion, the fully denoised image generated by the base model
+can be further improved using the [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+
+For this, you simply run the refiner as a normal image-to-image pipeline after the "base" text-to-image
+pipeline. You can leave the outputs of the base model in latent space.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+refiner = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-0.9",
+    text_encoder_2=pipe.text_encoder_2,
+    vae=pipe.vae,
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16",
 )
 refiner.to("cuda")
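+
+One can then chain the two calls as sketched below (an illustrative addition; the exact invocation lies outside
+this hunk). It mirrors the ensemble example above, except that the base model runs to full completion:
+
+```py
+prompt = "A majestic lion jumping from a big stone at night"
+
+# run the base model through all denoising steps, keeping the output in latent space
+image = pipe(prompt=prompt, output_type="latent").images
+
+# refine the fully denoised latents in image-to-image fashion
+image = refiner(prompt=prompt, image=image).images[0]
+```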
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 21f08b40f20b..b3dcf1b67cda 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -545,6 +545,7 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
+        denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -579,6 +580,14 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. For instance, if `denoising_end` is set
+                to 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
+                denoising steps. As a result, the returned sample will still retain a substantial amount of noise.
+                `denoising_end` should ideally be utilized when this pipeline forms part of a "Mixture of Denoisers"
+                multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -746,7 +755,13 @@ def __call__(
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 8.1 Apply denoising_end
+        if denoising_end is not None:
+            num_inference_steps = int(round(denoising_end * num_inference_steps))
+            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index e53b1f68e288..7b0cdfad8c0a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -456,11 +456,24 @@ def prepare_extra_step_kwargs(self, generator, eta):
         return extra_step_kwargs
 
     def check_inputs(
-        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+        self,
+        prompt,
+        strength,
+        num_inference_steps,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
+        if num_inference_steps is None:
+            raise ValueError("`num_inference_steps` cannot be None.")
+        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -495,16 +508,21 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
         # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
+        else:
+            t_start = int(round(denoising_start * num_inference_steps))
 
-        t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents(
+        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+    ):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
@@ -557,11 +575,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         else:
             init_latents = torch.cat([init_latents], dim=0)
 
-        shape = init_latents.shape
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if add_noise:
+            shape = init_latents.shape
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # get latents
+            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
 
-        # get latents
-        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
 
         return latents
@@ -620,6 +639,8 @@ def __call__(
         ] = None,
         strength: float = 0.3,
         num_inference_steps: int = 50,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -651,7 +672,7 @@ def __call__(
                 instead.
             image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                 The image(s) to modify with the pipeline.
-            strength (`float`, *optional*, defaults to 0.8):
+            strength (`float`, *optional*, defaults to 0.3):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                 denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
@@ -660,6 +681,24 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and
+                `num_inference_steps` is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50)
+                denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed
+                that the passed `image` is a partly denoised image.
+                The `denoising_start` parameter is particularly
+                beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as
+                detailed in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. For instance, if `denoising_end` is set
+                to 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
+                denoising steps. As a result, the returned sample will still retain a substantial amount of noise
+                (ca. 30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so
+                that it only denoises the final 30%. `denoising_end` should ideally be utilized when this pipeline
+                forms part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -738,7 +777,15 @@ def __call__(
                 "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(
+            prompt,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -781,13 +828,25 @@ def __call__(
         image = self.image_processor.preprocess(image)
 
         # 5. Prepare timesteps
+        original_num_steps = num_inference_steps  # save for denoising_start/end later
+
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device, denoising_start=denoising_start
+        )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # when denoising_start is given, the passed image is already a noisy latent, so no noise is added
+        add_noise = denoising_start is None
 
         # 6. Prepare latent variables
         latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            image,
+            latent_timestep,
+            batch_size,
+            num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            add_noise,
         )
 
         # 7. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -820,7 +879,22 @@ def __call__(
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
         # 9. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 9.1 Apply denoising_end
+        if denoising_end is not None and denoising_start is not None:
+            if denoising_start >= denoising_end:
+                raise ValueError(
+                    f"`denoising_end`: {denoising_end} cannot be smaller than `denoising_start`: {denoising_start}."
+                )
+
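+            # `denoising_end` is defined as a fraction of the original step count, so translate it into
+            # the number of trailing steps to drop from the (already `strength`-shortened) schedule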
+            skipped_final_steps = int(round((1 - denoising_end) * original_num_steps))
+            num_inference_steps = num_inference_steps - skipped_final_steps
+            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
+        elif denoising_end is not None:
+            num_inference_steps = int(round(denoising_end * num_inference_steps))
+            timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index e481e85916d2..3f4dd19c9bd9 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import unittest
 
 import numpy as np
@@ -21,9 +22,14 @@
 from diffusers import (
     AutoencoderKL,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    StableDiffusionXLImg2ImgPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
+    UniPCMultistepScheduler,
 )
 
 from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
@@ -216,3 +222,130 @@ def test_stable_diffusion_xl_offloads(self):
 
         assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
         assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
+
+    def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
+        components = self.get_dummy_components()
+        pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
+        pipe_1.unet.set_default_attn_processor()
+        pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipe_2.unet.set_default_attn_processor()
+
+        def assert_run_mixture(num_steps, split, scheduler_cls):
+            inputs = self.get_dummy_inputs(torch_device)
+            inputs["num_inference_steps"] = num_steps
+
+            pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+            pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+
+            # Let's retrieve the number of timesteps we want to use
+            pipe_1.scheduler.set_timesteps(num_steps)
+            expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+            split_id = int(round(split * num_steps)) * pipe_1.scheduler.order
+            expected_steps_1 = expected_steps[:split_id]
+            expected_steps_2 = expected_steps[split_id:]
+
+            # now we monkey patch the scheduler's `step` function to record every
+            # timestep it receives in the `done_steps` list
+            done_steps = []
+            old_step = copy.copy(scheduler_cls.step)
+
+            def new_step(self, *args, **kwargs):
+                done_steps.append(args[1].cpu().item())  # args[1] is always the passed `t`
+                return old_step(self, *args, **kwargs)
+
+            scheduler_cls.step = new_step
+
+            inputs_1 = {**inputs, **{"denoising_end": split, "output_type": "latent"}}
+            latents = pipe_1(**inputs_1).images[0]
+
+            assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+            inputs_2 = {**inputs, **{"denoising_start": split, "image": latents}}
+            pipe_2(**inputs_2).images[0]
+
+            assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+            assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+        for steps in [5, 8]:
+            for split in [0.33, 0.49, 0.71]:
+                for scheduler_cls in [
+                    DDIMScheduler,
+                    EulerDiscreteScheduler,
+                    DPMSolverMultistepScheduler,
+                    UniPCMultistepScheduler,
+                    HeunDiscreteScheduler,
+                ]:
+                    assert_run_mixture(steps, split, scheduler_cls)
+
+    def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
+        components = self.get_dummy_components()
+        pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
+        pipe_1.unet.set_default_attn_processor()
+        pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipe_2.unet.set_default_attn_processor()
+        pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipe_3.unet.set_default_attn_processor()
+
+        def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls):
+            inputs = self.get_dummy_inputs(torch_device)
+            inputs["num_inference_steps"] = num_steps
+
+            pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+            pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+            pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)
+
+            # Let's retrieve the number of timesteps we want to use
+            pipe_1.scheduler.set_timesteps(num_steps)
+            expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+            split_id_1 = int(round(split_1 * num_steps)) * pipe_1.scheduler.order
+            split_id_2 = int(round(split_2 * num_steps)) * pipe_1.scheduler.order
+            expected_steps_1 = expected_steps[:split_id_1]
+            expected_steps_2 = expected_steps[split_id_1:split_id_2]
+            expected_steps_3 = expected_steps[split_id_2:]
+
+            # now we monkey patch the scheduler's `step` function to record every
+            # timestep it receives in the `done_steps` list
+            done_steps = []
+            old_step = copy.copy(scheduler_cls.step)
+
+            def new_step(self, *args, **kwargs):
+                done_steps.append(args[1].cpu().item())  # args[1] is always the passed `t`
+                return old_step(self, *args, **kwargs)
+
+            scheduler_cls.step = new_step
+
+            inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
+            latents = pipe_1(**inputs_1).images[0]
+
+            assert (
+                expected_steps_1 == done_steps
+            ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+
+            inputs_2 = {
+                **inputs,
+                **{"denoising_start": split_1, "denoising_end": split_2, "image": latents, "output_type": "latent"},
+            }
+            pipe_2(**inputs_2).images[0]
+
+            assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+
+            inputs_3 = {**inputs, **{"denoising_start": split_2, "image": latents}}
+            pipe_3(**inputs_3).images[0]
+
+            assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
+            assert (
+                expected_steps == done_steps
+            ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+
+        for steps in [7, 11]:
+            for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
+                for scheduler_cls in [
+                    DDIMScheduler,
+                    EulerDiscreteScheduler,
+                    DPMSolverMultistepScheduler,
+                    UniPCMultistepScheduler,
+                    HeunDiscreteScheduler,
+                ]:
+                    assert_run_mixture(steps, split_1, split_2, scheduler_cls)
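
As a closing illustration of what these tests exercise end-to-end, here is a hedged sketch of a three-expert
mixture of denoisers using the public 0.9 checkpoints instead of the test dummies. The split points 0.6 and
0.85 are arbitrary choices for illustration, not values taken from this diff:

```py
from diffusers import DiffusionPipeline
import torch

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

prompt = "A majestic lion jumping from a big stone at night"
n_steps = 40

# expert 1: the base model covers the high-noise part of the schedule (0% -> 60%)
latents = base(prompt=prompt, num_inference_steps=n_steps, denoising_end=0.6, output_type="latent").images

# expert 2: the refiner covers the middle of the schedule (60% -> 85%), still returning latents
latents = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=0.6,
    denoising_end=0.85,
    image=latents,
    output_type="latent",
).images

# expert 3: the refiner finishes the low-noise part (85% -> 100%) and decodes to a PIL image
image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=0.85, image=latents).images[0]
```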