Add guidance start/end parameters to StableDiffusionControlNetImg2ImgPipeline #2731

Merged · 2 commits · Mar 21, 2023
67 changes: 53 additions & 14 deletions examples/community/stable_diffusion_controlnet_img2img.py
@@ -437,6 +437,8 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         strength=None,
+        controlnet_guidance_start=None,
+        controlnet_guidance_end=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -542,7 +544,23 @@ def check_inputs(
             )

         if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+            raise ValueError(f"The value of `strength` should be in [0.0, 1.0] but is {strength}")
+
+        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {controlnet_guidance_start}"
+            )
+
+        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {controlnet_guidance_end}"
+            )
+
+        if controlnet_guidance_start > controlnet_guidance_end:
+            raise ValueError(
+                "The value of `controlnet_guidance_start` should not be greater than `controlnet_guidance_end`, but got"
+                f" `controlnet_guidance_start` {controlnet_guidance_start} > `controlnet_guidance_end` {controlnet_guidance_end}"
+            )

     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
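
To make the new checks concrete, here is a standalone replica of the validation (a hedged sketch; `validate_guidance_window` is a hypothetical helper, not part of the PR — in the pipeline these checks run inside `check_inputs`):

def validate_guidance_window(start, end):
    # Both bounds must lie in [0.0, 1.0].
    if start < 0 or start > 1:
        raise ValueError(f"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {start}")
    if end < 0 or end > 1:
        raise ValueError(f"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {end}")
    # The window must be ordered; equal bounds are allowed by the `>` check.
    if start > end:
        raise ValueError(f"`controlnet_guidance_start` {start} should not exceed `controlnet_guidance_end` {end}")

validate_guidance_window(0.0, 0.5)    # ok: controlnet active for the first half of the steps
validate_guidance_window(0.3, 0.3)    # ok: equal bounds pass (single-point window)
# validate_guidance_window(0.5, 0.2)  # would raise: start exceeds end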
@@ -643,6 +661,8 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
+        controlnet_guidance_start: float = 0.0,
+        controlnet_guidance_end: float = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -719,6 +739,11 @@ def __call__(
             controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet.
+            controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
+                The fraction of total steps at which the controlnet starts being applied. Must be between 0 and 1.
+            controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
+                The fraction of total steps at which the controlnet stops being applied. Must be between 0 and 1,
+                and must not be less than `controlnet_guidance_start`.

         Examples:

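Since the docstring's Examples section is elided here, a hypothetical call using the new window might look like the sketch below (model IDs, image URLs, and the prompt are placeholders, not from the PR; the parameter names follow the pipeline file shown in this diff):

import torch
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.utils import load_image

# Load the community pipeline with a canny controlnet (placeholder model IDs).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_img2img",
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("https://example.com/init.png")    # placeholder URL
canny_image = load_image("https://example.com/canny.png")  # placeholder URL

# Only apply the controlnet during the first half of the denoising steps.
result = pipe(
    prompt="a photo of a house",  # placeholder prompt
    image=init_image,
    controlnet_conditioning_image=canny_image,
    strength=0.8,
    controlnet_guidance_start=0.0,
    controlnet_guidance_end=0.5,
).images[0]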
@@ -745,6 +770,8 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             strength,
+            controlnet_guidance_start,
+            controlnet_guidance_end,
         )

         # 2. Define call parameters
@@ -820,19 +847,31 @@ def __call__(

                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=controlnet_conditioning_image,
-                    return_dict=False,
-                )
-
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
+                # compute the fraction of total steps we are at
+                current_sampling_percent = i / len(timesteps)
+
+                if (
+                    current_sampling_percent < controlnet_guidance_start
+                    or current_sampling_percent > controlnet_guidance_end
+                ):
+                    # do not apply the controlnet
+                    down_block_res_samples = None
+                    mid_block_res_sample = None
+                else:
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        return_dict=False,
+                    )
+
+                    down_block_res_samples = [
+                        down_block_res_sample * controlnet_conditioning_scale
+                        for down_block_res_sample in down_block_res_samples
+                    ]
+                    mid_block_res_sample *= controlnet_conditioning_scale

                 # predict the noise residual
                 noise_pred = self.unet(
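
The gating above can be traced in isolation. A minimal standalone sketch of the rule the loop implements (`num_inference_steps` and the window values are just illustrative):

num_inference_steps = 10
controlnet_guidance_start, controlnet_guidance_end = 0.2, 0.6

for i in range(num_inference_steps):
    # Same fraction the pipeline computes: i / len(timesteps).
    current_sampling_percent = i / num_inference_steps
    applied = controlnet_guidance_start <= current_sampling_percent <= controlnet_guidance_end
    print(f"step {i}: percent={current_sampling_percent:.1f} controlnet={'on' if applied else 'off'}")

# Steps 2-6 (fractions 0.2-0.6) run with the controlnet residuals; for the other steps the
# unet is called with down_block_res_samples=None and mid_block_res_sample=None, i.e. plain
# img2img denoising.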