
Commit 2f142f6

bghira authored and committed
diffusers#4003 - initial implementation of max_inference_steps and first_inference_step for img2img
1 parent 770feb7 commit 2f142f6

File tree: 1 file changed (+34, -3 lines)


src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 34 additions & 3 deletions
@@ -455,11 +455,27 @@ def prepare_extra_step_kwargs(self, generator, eta):
         return extra_step_kwargs

     def check_inputs(
-        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+        self, prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
+        if num_inference_steps is None:
+            raise ValueError("`num_inference_steps` cannot be None.")
+        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
+        if first_inference_step is not None and (not isinstance(first_inference_step, int) or first_inference_step <= 0):
+            raise ValueError(
+                f"`first_inference_step` has to be a positive integer but is {first_inference_step} of type"
+                f" {type(first_inference_step)}."
+            )
+        if first_inference_step is not None and first_inference_step > num_inference_steps:
+            raise ValueError(
+                f"`first_inference_step` has to be smaller than `num_inference_steps` but is {first_inference_step} and"
+                f" `num_inference_steps` is {num_inference_steps}."
+            )
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -619,6 +635,8 @@ def __call__(
         ] = None,
         strength: float = 0.3,
         num_inference_steps: int = 50,
+        max_inference_steps: Optional[int] = None,
+        first_inference_step: Optional[int] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -659,6 +677,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            max_inference_steps (`int`, *optional*):
+                Instead of completing the backwards pass entirely, stop and return the output after this many steps.
+                Can be useful with `output_type="latent"` and an img2img pipeline, possibly with better fine detail.
+            first_inference_step (`int`, *optional*):
+                Ignore the first steps of the denoising process, and start from here.
+                Useful if the input is a latent tensor that still has residual noise, eg. using `max_inference_steps`.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -737,7 +761,7 @@ def __call__(
             "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -822,6 +846,10 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # skip a number of timesteps, if first_inference_step is set
+                if first_inference_step is not None and i < first_inference_step:
+                    print(f'Skipping timestep {i} of {num_inference_steps} because of first_inference_step={first_inference_step}')
+                    continue
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

@@ -855,6 +883,9 @@ def __call__(
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
+                if max_inference_steps is not None and i >= max_inference_steps:
+                    logger.debug(f'Breaking inference loop at step {i} as we have reached max_inference_steps={max_inference_steps}')
+                    break

         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
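
As the new docstrings suggest, the two parameters are meant to split one denoising schedule across two pipeline calls: stop early with `max_inference_steps` and `output_type="latent"`, then resume with `first_inference_step`. Below is a rough usage sketch, not part of the commit; the model id, prompt, input image, and step counts are placeholders, and `strength=1.0` is set so the full 50-step schedule is kept in both passes (in img2img, roughly `int(num_inference_steps * strength)` steps actually run).

# Illustrative sketch only, not from the commit: splitting one 50-step schedule
# across two img2img calls using max_inference_steps / first_inference_step.
import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

init_image = Image.new("RGB", (1024, 1024))  # stand-in for a real input image
prompt = "a photo of an astronaut riding a horse"

# Pass 1: run part of the schedule, then break out of the denoising loop and
# return the partially denoised latents instead of a decoded image.
latents = pipe(
    prompt=prompt,
    image=init_image,
    strength=1.0,               # keep all 50 timesteps in the schedule
    num_inference_steps=50,
    max_inference_steps=35,     # break the loop once the step index reaches 35
    output_type="latent",
).images

# Pass 2: feed the still-noisy latents back in and skip the steps already taken.
image = pipe(
    prompt=prompt,
    image=latents,
    strength=1.0,
    num_inference_steps=50,
    first_inference_step=35,    # skip step indices below 35
).images[0]

The exact hand-off point (whether the boundary step runs once or twice, and how the second pass re-noises a latent input) depends on the loop and `prepare_latents` behavior shown above, so the indices here are only indicative.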
