
Commit c63a53a

Author: bghira (committed)
diffusers#4003 Do not add noise during img2img when we have a defined first timestep
1 parent a4c6217 commit c63a53a

1 file changed: +12 −10 lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 12 additions & 10 deletions
@@ -515,14 +515,14 @@ def get_timesteps(self, num_inference_steps, first_inference_step, strength, device):
         if first_inference_step is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
         else:
-            init_timestep = first_inference_step
+            init_timestep = first_inference_step - 1
 
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
@@ -574,12 +574,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
             )
         else:
             init_latents = torch.cat([init_latents], dim=0)
+        if add_noise:
+            shape = init_latents.shape
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        shape = init_latents.shape
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
-        # get latents
-        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            # get latents
+            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
 
         return latents
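The new `add_noise` flag only gates whether the encoded image latents are perturbed before denoising. A rough standalone sketch of that branch with plain tensors follows; the `noise_image_latents` helper, the constant `sigma`, and the toy shapes are illustrative, not diffusers API:

import torch

def noise_image_latents(init_latents, add_noise=True, sigma=1.0, generator=None):
    # Mirrors the gated branch above: only perturb the encoded image when add_noise
    # is True; with a defined first timestep the latents pass through untouched.
    if add_noise:
        noise = torch.randn(init_latents.shape, generator=generator, dtype=init_latents.dtype)
        init_latents = init_latents + sigma * noise  # stand-in for scheduler.add_noise(...)
    return init_latents

latents = torch.zeros(1, 4, 128, 128)
assert torch.equal(noise_image_latents(latents, add_noise=False), latents)      # unchanged
assert not torch.equal(noise_image_latents(latents, add_noise=True), latents)   # noised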
@@ -671,7 +671,7 @@ def __call__(
                 instead.
             image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                 The image(s) to modify with the pipeline.
-            strength (`float`, *optional*, defaults to 0.8):
+            strength (`float`, *optional*, defaults to 0.3):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                 denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
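As a quick check of what the lower documented default means when `first_inference_step` is not given (the step count of 50 is chosen here for illustration):

# Denoising steps that actually run for a given strength (same formula as in
# get_timesteps above, assuming first_inference_step is None).
num_inference_steps = 50
for strength in (0.8, 0.3):  # previous default vs. new default
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    print(f"strength={strength}: {num_inference_steps - t_start} denoising steps")
# strength=0.8: 40 denoising steps
# strength=0.3: 15 denoising steps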
@@ -810,10 +810,12 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, first_inference_step, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-
+        add_noise = True
+        if first_inference_step is not None:
+            add_noise = False
         # 6. Prepare latent variables
         latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, add_noise
         )
         # 7. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
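Putting it together, a hypothetical driver for the patched `__call__` might look like the sketch below. It assumes this fork exposes `first_inference_step` as a pipeline argument (as the `get_timesteps` call above implies), that 4-channel latents can be passed as `image`, and uses an illustrative model id, split step, and random placeholder latents:

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

# Assumed to be latents from a base pass stopped partway through its schedule;
# producing them is outside the scope of this diff, so a placeholder is used here.
partial_latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")

image = refiner(
    prompt="a photo of an astronaut riding a horse",
    image=partial_latents,
    num_inference_steps=50,
    first_inference_step=35,  # picks the schedule up at step 35; add_noise becomes False
).images[0]

Because `first_inference_step` is set, the pipeline slices the schedule from that step and skips `scheduler.add_noise`, so the supplied latents are refined as-is rather than being re-noised.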
