@@ -510,9 +510,12 @@ def check_inputs(
                    f" {negative_prompt_embeds.shape}."
                )

-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, first_inference_step, strength, device):
        # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        if first_inference_step is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        else:
+            init_timestep = first_inference_step

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
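For reference, here is a minimal, dependency-free sketch of the branch this hunk introduces. The toy `scheduler_timesteps` list, the `order` argument, and the return statement (which falls outside the hunk and is assumed to mirror the stock diffusers helper) are illustration-only assumptions, not part of the diff.

```python
# Standalone sketch of the new init_timestep branch (toy inputs; the return
# line is assumed to match the upstream helper and is not shown in the hunk).
def get_timesteps_sketch(num_inference_steps, first_inference_step, strength, scheduler_timesteps, order=1):
    if first_inference_step is None:
        # strength-based behaviour, unchanged from before
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    else:
        # new path: init_timestep is taken directly from first_inference_step
        init_timestep = first_inference_step

    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler_timesteps[t_start * order:]
    return timesteps, num_inference_steps - t_start


toy_schedule = list(range(1000, 0, -100))                 # 10 descending toy timesteps
print(get_timesteps_sketch(10, None, 0.3, toy_schedule))  # ([300, 200, 100], 3)
print(get_timesteps_sketch(10, 5, 0.3, toy_schedule))     # ([500, 400, 300, 200, 100], 5)
```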
@@ -805,7 +808,7 @@ def __call__(

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, first_inference_step, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
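A small sketch of the unchanged line just below the modified call, showing how `latent_timestep` is broadcast to the batch. The tensor values, `batch_size`, and `num_images_per_prompt` are made-up illustration values.

```python
# Toy illustration of latent_timestep = timesteps[:1].repeat(...) from above.
import torch

timesteps = torch.tensor([301, 201, 101])  # e.g. a schedule already trimmed by get_timesteps
batch_size, num_images_per_prompt = 2, 2

latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
print(latent_timestep)  # tensor([301, 301, 301, 301]) -> one starting timestep per latent
```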
@@ -846,10 +849,6 @@ def __call__(
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
-                # skip a number of timesteps, if first_inference_step is set
-                if first_inference_step is not None and i < first_inference_step:
-                    print(f'Skipping timestep {i} of {num_inference_steps} because of first_inference_step={first_inference_step}')
-                    continue
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
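Because `get_timesteps` now trims the schedule before the loop starts, the in-loop skip (and its `print`) is dropped here. A toy sketch of the resulting loop, with made-up timestep values:

```python
# Toy sketch of the denoising loop after this change: the schedule handed to
# the loop is already trimmed, so no iteration needs to be skipped.
timesteps = [400, 300, 200, 100]        # e.g. already sliced by get_timesteps
num_inference_steps = len(timesteps)

for i, t in enumerate(timesteps):
    # every iteration does real work; previously some were skipped with `continue`
    print(f"denoising step {i + 1}/{num_inference_steps} at timestep {t}")
```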