@@ -533,13 +533,20 @@ def get_timesteps(self, num_inference_steps, begin_inference_step, strength, dev
533
533
# get the original timestep using init_timestep
534
534
if begin_inference_step is None :
535
535
init_timestep = min (int (num_inference_steps * strength ), num_inference_steps )
536
+ t_start = max (num_inference_steps - init_timestep , 0 )
537
+ new_num_inference_steps = num_inference_steps - t_start
538
+ target_count = t_start * self .scheduler .order
536
539
else :
537
- init_timestep = begin_inference_step - 1
540
+ init_timestep = begin_inference_step
541
+ t_start = max (num_inference_steps - init_timestep , 0 )
542
+ new_num_inference_steps = num_inference_steps - (num_inference_steps - t_start )
543
+ target_count = num_inference_steps
544
+ logger .info (f'Calculating t_start via max({ num_inference_steps } - { init_timestep } , 0)' )
545
+ logger .info (f't_start is { t_start } ' )
546
+ logger .info (f'We will have { target_count } timesteps from { len (self .scheduler .timesteps )} total' )
547
+ timesteps = self .scheduler .timesteps [target_count :]
538
548
539
- t_start = max (num_inference_steps - init_timestep , 0 )
540
- timesteps = self .scheduler .timesteps [t_start * self .scheduler .order :]
541
-
542
- return timesteps , num_inference_steps - t_start
549
+ return timesteps , new_num_inference_steps
543
550
544
551
def prepare_latents (self , image , timestep , batch_size , num_images_per_prompt , dtype , device , generator = None , add_noise = True ):
545
552
if not isinstance (image , (torch .Tensor , PIL .Image .Image , list )):
@@ -924,6 +931,8 @@ def __call__(
924
931
num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
925
932
with self .progress_bar (total = num_inference_steps ) as progress_bar :
926
933
for i , t in enumerate (timesteps ):
934
+ if begin_inference_step is not None and i < begin_inference_step :
935
+ logger .info (f'Skipping step { i } because we are waiting for { begin_inference_step } ' )
927
936
# expand the latents if we are doing classifier free guidance
928
937
latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
929
938
0 commit comments