Commit 58bcf46

Add guidance start/end parameters to StableDiffusionControlNetImg2ImgPipeline (#2731)
* Add guidance start/end parameters to community controlnet img2img pipeline
* Fix formats
1 parent 0042efd commit 58bcf46
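For context, here is a minimal usage sketch of the new guidance window. It is not part of the commit; it assumes the community pipeline is loaded through the `custom_pipeline` mechanism, that standard Stable Diffusion 1.5 and Canny ControlNet checkpoints are available, and that an init image plus a matching edge map are already prepared. The argument names follow the pipeline's existing `__call__` signature; the image URLs are placeholders.

```python
# Sketch: calling the community img2img controlnet pipeline with the new window.
import torch
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="stable_diffusion_controlnet_img2img",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("https://example.com/init.png")          # placeholder URL
canny_image = load_image("https://example.com/canny_edges.png")  # placeholder URL

result = pipe(
    prompt="a photo of a futuristic city",
    image=init_image,
    controlnet_conditioning_image=canny_image,
    strength=0.8,
    num_inference_steps=50,
    controlnet_conditioning_scale=1.0,
    # New in this commit: only apply the controlnet for the first 80% of steps.
    controlnet_guidance_start=0.0,
    controlnet_guidance_end=0.8,
).images[0]
```

Setting `controlnet_guidance_end` below 1.0 lets the controlnet shape the early structure of the image while leaving the final denoising steps unconstrained.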

File tree

1 file changed (+53, -14)
examples/community/stable_diffusion_controlnet_img2img.py

Lines changed: 53 additions & 14 deletions
```diff
@@ -437,6 +437,8 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         strength=None,
+        controlnet_guidance_start=None,
+        controlnet_guidance_end=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
```
```diff
@@ -542,7 +544,23 @@ def check_inputs(
                 )

         if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+            raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}")
+
+        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}"
+            )
+
+        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}"
+            )
+
+        if controlnet_guidance_start > controlnet_guidance_end:
+            raise ValueError(
+                "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
+                f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
+            )

     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
```
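In plain terms, the new checks reject any guidance window that leaves [0.0, 1.0] or is inverted. A standalone sketch of the same rules, using a hypothetical helper name that is not part of the pipeline:

```python
# Standalone sketch of the new validation rules; `validate_guidance_window`
# is an illustrative helper, not part of diffusers.
def validate_guidance_window(start: float, end: float) -> None:
    if not 0.0 <= start <= 1.0:
        raise ValueError(f"`controlnet_guidance_start` must be in [0.0, 1.0], got {start}")
    if not 0.0 <= end <= 1.0:
        raise ValueError(f"`controlnet_guidance_end` must be in [0.0, 1.0], got {end}")
    if start > end:
        raise ValueError(f"`controlnet_guidance_start` ({start}) must not exceed `controlnet_guidance_end` ({end})")


validate_guidance_window(0.0, 0.8)     # fine: controlnet active for the first 80% of steps
# validate_guidance_window(0.9, 0.2)   # would raise ValueError: inverted window
```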
```diff
@@ -643,6 +661,8 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
+        controlnet_guidance_start: float = 0.0,
+        controlnet_guidance_end: float = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
```
```diff
@@ -719,6 +739,11 @@ def __call__(
             controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet.
+            controlnet_guidance_start ('float', *optional*, defaults to 0.0):
+                The percentage of total steps the controlnet starts applying. Must be between 0 and 1.
+            controlnet_guidance_end ('float', *optional*, defaults to 1.0):
+                The percentage of total steps the controlnet ends applying. Must be between 0 and 1. Must be greater
+                than `controlnet_guidance_start`.

         Examples:

```
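The docstring phrases the window as a percentage of total steps; in the denoising loop (final hunk below) it is computed against the timesteps actually run, which img2img already shortens according to `strength`. As a rough worked example, assuming 50 denoising steps are run, `controlnet_guidance_start=0.2` and `controlnet_guidance_end=0.8` keep the controlnet active roughly from step 10 through step 40, while the defaults of 0.0 and 1.0 reproduce the previous always-on behaviour:

```python
# Sketch of the step-window arithmetic, assuming 50 denoising steps are
# actually run after the img2img `strength` cut.
num_steps = 50
controlnet_guidance_start, controlnet_guidance_end = 0.2, 0.8

active_steps = [
    i
    for i in range(num_steps)
    if controlnet_guidance_start <= i / num_steps <= controlnet_guidance_end
]
print(active_steps[0], active_steps[-1])  # 10 40
```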
```diff
@@ -745,6 +770,8 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             strength,
+            controlnet_guidance_start,
+            controlnet_guidance_end,
         )

         # 2. Define call parameters
```
```diff
@@ -820,19 +847,31 @@ def __call__(

                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=controlnet_conditioning_image,
-                    return_dict=False,
-                )
-
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
+                # compute the percentage of total steps we are at
+                current_sampling_percent = i / len(timesteps)
+
+                if (
+                    current_sampling_percent < controlnet_guidance_start
+                    or current_sampling_percent > controlnet_guidance_end
+                ):
+                    # do not apply the controlnet
+                    down_block_res_samples = None
+                    mid_block_res_sample = None
+                else:
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        return_dict=False,
+                    )
+
+                    down_block_res_samples = [
+                        down_block_res_sample * controlnet_conditioning_scale
+                        for down_block_res_sample in down_block_res_samples
+                    ]
+                    mid_block_res_sample *= controlnet_conditioning_scale

                 # predict the noise residual
                 noise_pred = self.unet(
```