Skip to content

Commit 942c6a5

Browse files
author
bghira
committed
diffusers#4003 Correct the method used for skipping steps, and the count of steps
1 parent 77469f5 commit 942c6a5

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,20 @@ def get_timesteps(self, num_inference_steps, begin_inference_step, strength, dev
533533
# get the original timestep using init_timestep
534534
if begin_inference_step is None:
535535
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
536+
t_start = max(num_inference_steps - init_timestep, 0)
537+
new_num_inference_steps = num_inference_steps - t_start
538+
target_count = t_start * self.scheduler.order
536539
else:
537-
init_timestep = begin_inference_step - 1
540+
init_timestep = begin_inference_step
541+
t_start = max(num_inference_steps - init_timestep, 0)
542+
new_num_inference_steps = num_inference_steps - (num_inference_steps - t_start)
543+
target_count = num_inference_steps
544+
logger.info(f'Calculating t_start via max({num_inference_steps} - {init_timestep}, 0)')
545+
logger.info(f't_start is {t_start}')
546+
logger.info(f'We will have {target_count} timesteps from {len(self.scheduler.timesteps)} total')
547+
timesteps = self.scheduler.timesteps[target_count :]
538548

539-
t_start = max(num_inference_steps - init_timestep, 0)
540-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
541-
542-
return timesteps, num_inference_steps - t_start
549+
return timesteps, new_num_inference_steps
543550

544551
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True):
545552
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
@@ -924,6 +931,8 @@ def __call__(
924931
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
925932
with self.progress_bar(total=num_inference_steps) as progress_bar:
926933
for i, t in enumerate(timesteps):
934+
if begin_inference_step is not None and i < begin_inference_step:
935+
logger.info(f'Skipping step {i} because we are waiting for {begin_inference_step}')
927936
# expand the latents if we are doing classifier free guidance
928937
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
929938

0 commit comments

Comments (0)