From 58c2dc91b3f6ff53b9e9c8824edaf9e2ca38970b Mon Sep 17 00:00:00 2001 From: Hyowon Ha Date: Sat, 18 Mar 2023 17:38:38 +0900 Subject: [PATCH 1/2] Add guidance start/end parameters to community controlnet img2img pipeline --- .../stable_diffusion_controlnet_img2img.py | 60 ++++++++++++++----- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 5aa5e47c6578..6298a373a1c5 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -437,6 +437,8 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, strength=None, + controlnet_guidance_start=None, + controlnet_guidance_end=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -542,7 +544,19 @@ def check_inputs( ) if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}") + + if controlnet_guidance_start < 0 or controlnet_guidance_start > 1: + raise ValueError(f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}") + + if controlnet_guidance_end < 0 or controlnet_guidance_end > 1: + raise ValueError(f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}") + + if controlnet_guidance_start > controlnet_guidance_end: + raise ValueError( + "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got" + f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}" + ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -643,6 +657,8 @@ def __call__( callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: float = 1.0, + controlnet_guidance_start: float = 0.0, + controlnet_guidance_end: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -719,6 +735,11 @@ def __call__( controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. + controlnet_guidance_start ('float', *optional*, defaults to 0.0): + The percentage of total steps the controlnet starts applying. Must be between 0 and 1. + controlnet_guidance_end ('float', *optional*, defaults to 1.0): + The percentage of total steps the controlnet ends applying. Must be between 0 and 1. Must be greater + than `controlnet_guidance_start`. Examples: @@ -745,6 +766,8 @@ def __call__( prompt_embeds, negative_prompt_embeds, strength, + controlnet_guidance_start, + controlnet_guidance_end, ) # 2. Define call parameters @@ -820,19 +843,28 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - down_block_res_samples, mid_block_res_sample = self.controlnet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=controlnet_conditioning_image, - return_dict=False, - ) - - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale + # compute the percentage of total steps we are at + current_sampling_percent = i / len(timesteps) + + if current_sampling_percent < controlnet_guidance_start or current_sampling_percent > controlnet_guidance_end: + # do not apply the controlnet + down_block_res_samples = None + mid_block_res_sample = None + else: + # apply the controlnet + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=controlnet_conditioning_image, + return_dict=False, + ) + + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale # predict the noise residual noise_pred = self.unet( From 13b75cde26c66f8db85cfbbb12a41c091693e129 Mon Sep 17 00:00:00 2001 From: Hyowon Ha Date: Sat, 18 Mar 2023 18:42:55 +0900 Subject: [PATCH 2/2] Fix formats --- .../stable_diffusion_controlnet_img2img.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 6298a373a1c5..51533a92d84a 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -547,10 +547,14 @@ def check_inputs( raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}") if controlnet_guidance_start < 0 or controlnet_guidance_start > 1: - raise ValueError(f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}") + raise ValueError( + f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}" + ) if controlnet_guidance_end < 0 or controlnet_guidance_end > 1: - raise ValueError(f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}") + raise ValueError( + f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}" + ) if controlnet_guidance_start > controlnet_guidance_end: raise ValueError( @@ -846,7 +850,10 @@ def __call__( # compute the percentage of total steps we are at current_sampling_percent = i / len(timesteps) - if current_sampling_percent < controlnet_guidance_start or current_sampling_percent > controlnet_guidance_end: + if ( + current_sampling_percent < controlnet_guidance_start + or current_sampling_percent > controlnet_guidance_end + ): # do not apply the controlnet down_block_res_samples = None mid_block_res_sample = None