Skip to content

Commit 062afbb

Browse files
author
bghira
committed
diffusers#4003 Mild updates after revert
1 parent c63a53a commit 062afbb

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,25 @@ def prepare_extra_step_kwargs(self, generator, eta):
444444
extra_step_kwargs["generator"] = generator
445445
return extra_step_kwargs
446446

447+
def timesteps_from_strength(
448+
self, strength: float, num_inference_steps: int
449+
):
450+
"""Retrieve values for `final_inference_step` and `begin_inference_step` from `strength`, `num_inference_steps`
451+
452+
Args:
453+
strength (float): A traditional img2img strength between 0.0 and 1.0, with higher values resulting in greater
454+
influence from the img2img model and lower values, more influence from the base model.
455+
num_inference_steps (int): The total number of inference steps to be taken.
456+
Returns:
457+
final_inference_step (int): The final inference step to be taken.
458+
begin_inference_step (int): The inference step to begin img2img inference.
459+
"""
460+
# We need to invert the percentage. A strength of 0.0 should result in 100% of the inference steps.
461+
inverse_strength = 1.0 - strength
462+
final_inference_step = int(num_inference_steps * inverse_strength)
463+
begin_inference_step = final_inference_step
464+
return final_inference_step, begin_inference_step
465+
447466
def check_inputs(
448467
self,
449468
prompt,

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,27 @@ def prepare_extra_step_kwargs(self, generator, eta):
454454
extra_step_kwargs["generator"] = generator
455455
return extra_step_kwargs
456456

457+
def timesteps_from_strength(
458+
self, strength: float, num_inference_steps: int
459+
):
460+
"""Retrieve values for `final_inference_step` and `begin_inference_step` from `strength`, `num_inference_steps`
461+
462+
Args:
463+
strength (float): A traditional img2img strength between 0.0 and 1.0, with higher values resulting in greater
464+
influence from the img2img model and lower values, more influence from the base model.
465+
num_inference_steps (int): The total number of inference steps to be taken.
466+
Returns:
467+
final_inference_step (int): The final inference step to be taken.
468+
begin_inference_step (int): The inference step to begin img2img inference.
469+
"""
470+
# We need to invert the percentage. A strength of 0.0 should result in 100% of the inference steps.
471+
inverse_strength = 1.0 - strength
472+
final_inference_step = int(num_inference_steps * inverse_strength)
473+
begin_inference_step = final_inference_step
474+
return final_inference_step, begin_inference_step
475+
457476
def check_inputs(
458-
self, prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
477+
self, prompt, strength, num_inference_steps, begin_inference_step, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
459478
):
460479
if strength < 0 or strength > 1:
461480
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -466,14 +485,14 @@ def check_inputs(
466485
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
467486
f" {type(num_inference_steps)}."
468487
)
469-
if first_inference_step is not None and (not isinstance(first_inference_step, int) or first_inference_step <= 0):
488+
if begin_inference_step is not None and (not isinstance(begin_inference_step, int) or begin_inference_step <= 0):
470489
raise ValueError(
471-
f"`first_inference_step` has to be a positive integer but is {first_inference_step} of type"
472-
f" {type(first_inference_step)}."
490+
f"`begin_inference_step` has to be a positive integer but is {begin_inference_step} of type"
491+
f" {type(begin_inference_step)}."
473492
)
474-
if first_inference_step is not None and first_inference_step > num_inference_steps:
493+
if begin_inference_step is not None and begin_inference_step > num_inference_steps:
475494
raise ValueError(
476-
f"`first_inference_step` has to be smaller than `num_inference_steps` but is {first_inference_step} and"
495+
f"`begin_inference_step` has to be smaller than `num_inference_steps` but is {begin_inference_step} and"
477496
f" `num_inference_steps` is {num_inference_steps}."
478497
)
479498
if (callback_steps is None) or (
@@ -510,12 +529,12 @@ def check_inputs(
510529
f" {negative_prompt_embeds.shape}."
511530
)
512531

513-
def get_timesteps(self, num_inference_steps, first_inference_step, strength, device):
532+
def get_timesteps(self, num_inference_steps, begin_inference_step, strength, device):
514533
# get the original timestep using init_timestep
515-
if first_inference_step is None:
534+
if begin_inference_step is None:
516535
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
517536
else:
518-
init_timestep = first_inference_step - 1
537+
init_timestep = begin_inference_step - 1
519538

520539
t_start = max(num_inference_steps - init_timestep, 0)
521540
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
@@ -638,8 +657,8 @@ def __call__(
638657
] = None,
639658
strength: float = 0.3,
640659
num_inference_steps: int = 50,
641-
max_inference_steps: Optional[int] = None,
642-
first_inference_step: Optional[int] = None,
660+
final_inference_step: Optional[int] = None,
661+
begin_inference_step: Optional[int] = None,
643662
guidance_scale: float = 5.0,
644663
negative_prompt: Optional[Union[str, List[str]]] = None,
645664
num_images_per_prompt: Optional[int] = 1,
@@ -680,12 +699,12 @@ def __call__(
680699
num_inference_steps (`int`, *optional*, defaults to 50):
681700
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
682701
expense of slower inference.
683-
max_inference_steps (`int`, *optional*):
702+
final_inference_step (`int`, *optional*):
684703
Instead of completing the backwards pass entirely, stop and return the output after this many steps.
685704
Can be useful with `output_type="latent"` and an img2img pipeline, possibly with better fine detail.
686-
first_inference_step (`int`, *optional*):
705+
begin_inference_step (`int`, *optional*):
687706
Ignore the first steps of the denoising process, and start from here.
688-
Useful if the input is a latent tensor that still has residual noise, eg. using `max_inference_steps`.
707+
Useful if the input is a latent tensor that still has residual noise, eg. using `final_inference_step`.
689708
guidance_scale (`float`, *optional*, defaults to 7.5):
690709
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
691710
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -764,7 +783,7 @@ def __call__(
764783
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
765784
"""
766785
# 1. Check inputs. Raise error if not correct
767-
self.check_inputs(prompt, strength, num_inference_steps, first_inference_step, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
786+
self.check_inputs(prompt, strength, num_inference_steps, begin_inference_step, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
768787

769788
# 2. Define call parameters
770789
if prompt is not None and isinstance(prompt, str):
@@ -808,10 +827,10 @@ def __call__(
808827

809828
# 5. Prepare timesteps
810829
self.scheduler.set_timesteps(num_inference_steps, device=device)
811-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, first_inference_step, strength, device)
830+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, begin_inference_step, strength, device)
812831
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
813832
add_noise = True
814-
if first_inference_step is not None:
833+
if begin_inference_step is not None:
815834
add_noise = False
816835
# 6. Prepare latent variables
817836
latents = self.prepare_latents(
@@ -884,8 +903,8 @@ def __call__(
884903
progress_bar.update()
885904
if callback is not None and i % callback_steps == 0:
886905
callback(i, t, latents)
887-
if max_inference_steps is not None and i >= max_inference_steps:
888-
logger.debug(f'Breaking inference loop at step {i} as we have reached max_inference_steps={max_inference_steps}')
906+
if final_inference_step is not None and i >= final_inference_step:
907+
logger.debug(f'Breaking inference loop at step {i} as we have reached final_inference_step={final_inference_step}')
889908
break
890909

891910
# make sure the VAE is in float32 mode, as it overflows in float16

0 commit comments

Comments
 (0)