@@ -454,8 +454,27 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
+    def timesteps_from_strength(
+        self, strength: float, num_inference_steps: int
+    ):
+        """Retrieve values for `final_inference_step` and `begin_inference_step` from `strength` and `num_inference_steps`.
+
+        Args:
+            strength (float): A traditional img2img strength between 0.0 and 1.0, where higher values give more
+                influence to the img2img model and lower values give more influence to the base model.
+            num_inference_steps (int): The total number of inference steps to be taken.
+        Returns:
+            final_inference_step (int): The final inference step to be taken.
+            begin_inference_step (int): The inference step at which to begin img2img inference.
+        """
+        # We need to invert the percentage. A strength of 0.0 should result in 100% of the inference steps.
+        inverse_strength = 1.0 - strength
+        final_inference_step = int(num_inference_steps * inverse_strength)
+        begin_inference_step = final_inference_step
+        return final_inference_step, begin_inference_step
+
     def check_inputs(
-        self, prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+        self, prompt, strength, num_inference_steps, begin_inference_step, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -466,14 +485,14 @@ def check_inputs(
                 f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
                 f" {type(num_inference_steps)}."
             )
-        if first_inference_step is not None and (not isinstance(first_inference_step, int) or first_inference_step <= 0):
+        if begin_inference_step is not None and (not isinstance(begin_inference_step, int) or begin_inference_step <= 0):
             raise ValueError(
-                f"`first_inference_step` has to be a positive integer but is {first_inference_step} of type"
-                f" {type(first_inference_step)}."
+                f"`begin_inference_step` has to be a positive integer but is {begin_inference_step} of type"
+                f" {type(begin_inference_step)}."
             )
-        if first_inference_step is not None and first_inference_step > num_inference_steps:
+        if begin_inference_step is not None and begin_inference_step > num_inference_steps:
             raise ValueError(
-                f"`first_inference_step` has to be smaller than `num_inference_steps` but is {first_inference_step} and"
+                f"`begin_inference_step` has to be smaller than `num_inference_steps` but is {begin_inference_step} and"
                 f" `num_inference_steps` is {num_inference_steps}."
             )
         if (callback_steps is None) or (
@@ -510,12 +529,12 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
-    def get_timesteps(self, num_inference_steps, first_inference_step, strength, device):
+    def get_timesteps(self, num_inference_steps, begin_inference_step, strength, device):
         # get the original timestep using init_timestep
-        if first_inference_step is None:
+        if begin_inference_step is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
         else:
-            init_timestep = first_inference_step - 1
+            init_timestep = begin_inference_step - 1
 
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
@@ -638,8 +657,8 @@ def __call__(
         ] = None,
         strength: float = 0.3,
         num_inference_steps: int = 50,
-        max_inference_steps: Optional[int] = None,
-        first_inference_step: Optional[int] = None,
+        final_inference_step: Optional[int] = None,
+        begin_inference_step: Optional[int] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -680,12 +699,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            max_inference_steps (`int`, *optional*):
+            final_inference_step (`int`, *optional*):
                 Instead of completing the backwards pass entirely, stop and return the output after this many steps.
                 Can be useful with `output_type="latent"` and an img2img pipeline, possibly with better fine detail.
-            first_inference_step (`int`, *optional*):
+            begin_inference_step (`int`, *optional*):
                 Ignore the first steps of the denoising process, and start from here.
-                Useful if the input is a latent tensor that still has residual noise, eg. using `max_inference_steps`.
+                Useful if the input is a latent tensor that still has residual noise, eg. using `final_inference_step`.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -764,7 +783,7 @@ def __call__(
                 "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(prompt, strength, num_inference_steps, begin_inference_step, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -808,10 +827,10 @@ def __call__(
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, first_inference_step, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, begin_inference_step, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         add_noise = True
-        if first_inference_step is not None:
+        if begin_inference_step is not None:
             add_noise = False
         # 6. Prepare latent variables
         latents = self.prepare_latents(
@@ -884,8 +903,8 @@ def __call__(
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
-                if max_inference_steps is not None and i >= max_inference_steps:
-                    logger.debug(f'Breaking inference loop at step {i} as we have reached max_inference_steps={max_inference_steps}')
+                if final_inference_step is not None and i >= final_inference_step:
+                    logger.debug(f'Breaking inference loop at step {i} as we have reached final_inference_step={final_inference_step}')
                     break

        # make sure the VAE is in float32 mode, as it overflows in float16
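For reference, below is a minimal standalone sketch of the arithmetic behind the `timesteps_from_strength` helper added in this diff and the `final_inference_step`/`begin_inference_step` split it produces. The `split_steps` name and the example values are illustrative only, not part of the commit:

from typing import Tuple


def split_steps(strength: float, num_inference_steps: int) -> Tuple[int, int]:
    """Mirror the helper above: convert an img2img strength into a pair of step indices."""
    # Invert the percentage: strength=0.0 keeps 100% of the steps for the base pass,
    # strength=1.0 leaves every step to the img2img pass.
    inverse_strength = 1.0 - strength
    final_inference_step = int(num_inference_steps * inverse_strength)
    # The base pass stops at `final_inference_step`; the img2img pass resumes at the same index.
    begin_inference_step = final_inference_step
    return final_inference_step, begin_inference_step


if __name__ == "__main__":
    # With the defaults in this diff (strength=0.3, num_inference_steps=50), the split
    # point falls at step index 35 for both the early stop and the resume.
    print(split_steps(0.3, 50))  # (35, 35)
    print(split_steps(0.0, 50))  # (50, 50) -> base pass runs every step
    print(split_steps(1.0, 50))  # (0, 0)   -> img2img pass runs every step

As the docstring for `final_inference_step` suggests, the intended flow is to stop a first pass early and return latents (`output_type="latent"`), then resume in this img2img pipeline from `begin_inference_step`, which also skips adding fresh noise (`add_noise = False`).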