Skip to content

Commit a4c6217

Browse files
author
bghira
committed
diffusers#4003 - use first_inference_step as an input arg for get_timesteps in img2img
1 parent 2f142f6 commit a4c6217

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -510,9 +510,12 @@ def check_inputs(
510510
f" {negative_prompt_embeds.shape}."
511511
)
512512

513-
def get_timesteps(self, num_inference_steps, strength, device):
513+
def get_timesteps(self, num_inference_steps, first_inference_step, strength, device):
514514
# get the original timestep using init_timestep
515-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
515+
if first_inference_step is None:
516+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
517+
else:
518+
init_timestep = first_inference_step
516519

517520
t_start = max(num_inference_steps - init_timestep, 0)
518521
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
@@ -805,7 +808,7 @@ def __call__(
805808

806809
# 5. Prepare timesteps
807810
self.scheduler.set_timesteps(num_inference_steps, device=device)
808-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
811+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, first_inference_step, strength, device)
809812
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
810813

811814
# 6. Prepare latent variables
@@ -846,10 +849,6 @@ def __call__(
846849
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
847850
with self.progress_bar(total=num_inference_steps) as progress_bar:
848851
for i, t in enumerate(timesteps):
849-
# skip a number of timesteps, if first_inference_step is set
850-
if first_inference_step is not None and i < first_inference_step:
851-
print(f'Skipping timestep {i} of {num_inference_steps} because of first_inference_step={first_inference_step}')
852-
continue
853852
# expand the latents if we are doing classifier free guidance
854853
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
855854

0 commit comments

Comments
 (0)