@@ -455,11 +455,27 @@ def prepare_extra_step_kwargs(self, generator, eta):
455
455
return extra_step_kwargs
456
456
457
457
def check_inputs (
458
- self , prompt , strength , callback_steps , negative_prompt = None , prompt_embeds = None , negative_prompt_embeds = None
458
+ self , prompt , strength , num_inference_steps , first_inference_step , callback_steps , negative_prompt = None , prompt_embeds = None , negative_prompt_embeds = None
459
459
):
460
460
if strength < 0 or strength > 1 :
461
461
raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
462
-
462
+ if num_inference_steps is None :
463
+ raise ValueError ("`num_inference_steps` cannot be None." )
464
+ elif not isinstance (num_inference_steps , int ) or num_inference_steps <= 0 :
465
+ raise ValueError (
466
+ f"`num_inference_steps` has to be a positive integer but is { num_inference_steps } of type"
467
+ f" { type (num_inference_steps )} ."
468
+ )
469
+ if first_inference_step is not None and (not isinstance (first_inference_step , int ) or first_inference_step <= 0 ):
470
+ raise ValueError (
471
+ f"`first_inference_step` has to be a positive integer but is { first_inference_step } of type"
472
+ f" { type (first_inference_step )} ."
473
+ )
474
+ if first_inference_step is not None and first_inference_step > num_inference_steps :
475
+ raise ValueError (
476
+ f"`first_inference_step` has to be smaller than `num_inference_steps` but is { first_inference_step } and"
477
+ f" `num_inference_steps` is { num_inference_steps } ."
478
+ )
463
479
if (callback_steps is None ) or (
464
480
callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 )
465
481
):
@@ -619,6 +635,8 @@ def __call__(
619
635
] = None ,
620
636
strength : float = 0.3 ,
621
637
num_inference_steps : int = 50 ,
638
+ max_inference_steps : Optional [int ] = None ,
639
+ first_inference_step : Optional [int ] = None ,
622
640
guidance_scale : float = 5.0 ,
623
641
negative_prompt : Optional [Union [str , List [str ]]] = None ,
624
642
num_images_per_prompt : Optional [int ] = 1 ,
@@ -659,6 +677,12 @@ def __call__(
659
677
num_inference_steps (`int`, *optional*, defaults to 50):
660
678
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
661
679
expense of slower inference.
680
+ max_inference_steps (`int`, *optional*):
681
+ Instead of completing the backwards pass entirely, stop and return the output after this many steps.
682
+ Can be useful with `output_type="latent"` and an img2img pipeline, possibly with better fine detail.
683
+ first_inference_step (`int`, *optional*):
684
+ Skip the first `first_inference_step` steps of the denoising process, and start denoising from that step.
685
+ Useful if the input is a latent tensor that still has residual noise, e.g. one produced by a prior pass that used `max_inference_steps`.
662
686
guidance_scale (`float`, *optional*, defaults to 7.5):
663
687
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
664
688
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -737,7 +761,7 @@ def __call__(
737
761
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
738
762
"""
739
763
# 1. Check inputs. Raise error if not correct
740
- self .check_inputs (prompt , strength , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds )
764
+ self .check_inputs (prompt , strength , num_inference_steps , first_inference_step , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds )
741
765
742
766
# 2. Define call parameters
743
767
if prompt is not None and isinstance (prompt , str ):
@@ -822,6 +846,10 @@ def __call__(
822
846
num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
823
847
with self .progress_bar (total = num_inference_steps ) as progress_bar :
824
848
for i , t in enumerate (timesteps ):
849
+ # skip a number of timesteps, if first_inference_step is set
850
+ if first_inference_step is not None and i < first_inference_step :
851
+ logger .debug (f'Skipping timestep { i } of { num_inference_steps } because of first_inference_step={ first_inference_step } ' )
852
+ continue
825
853
# expand the latents if we are doing classifier free guidance
826
854
latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
827
855
@@ -855,6 +883,9 @@ def __call__(
855
883
progress_bar .update ()
856
884
if callback is not None and i % callback_steps == 0 :
857
885
callback (i , t , latents )
886
+ if max_inference_steps is not None and i >= max_inference_steps :
887
+ logger .debug (f'Breaking inference loop at step { i } as we have reached max_inference_steps={ max_inference_steps } ' )
888
+ break
858
889
859
890
# make sure the VAE is in float32 mode, as it overflows in float16
860
891
self .vae .to (dtype = torch .float32 )
0 commit comments