@@ -437,6 +437,8 @@ def check_inputs(
437
437
prompt_embeds = None ,
438
438
negative_prompt_embeds = None ,
439
439
strength = None ,
440
+ controlnet_guidance_start = None ,
441
+ controlnet_guidance_end = None ,
440
442
):
441
443
if height % 8 != 0 or width % 8 != 0 :
442
444
raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -542,7 +544,23 @@ def check_inputs(
542
544
)
543
545
544
546
if strength < 0 or strength > 1 :
545
- raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
547
+ raise ValueError (f"The value of `strength` should in [0.0, 1.0] but is { strength } " )
548
+
549
+ if controlnet_guidance_start < 0 or controlnet_guidance_start > 1 :
550
+ raise ValueError (
551
+ f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is { controlnet_guidance_start } "
552
+ )
553
+
554
+ if controlnet_guidance_end < 0 or controlnet_guidance_end > 1 :
555
+ raise ValueError (
556
+ f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is { controlnet_guidance_end } "
557
+ )
558
+
559
+ if controlnet_guidance_start > controlnet_guidance_end :
560
+ raise ValueError (
561
+ "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
562
+ f" `controlnet_guidance_start` { controlnet_guidance_start } >= `controlnet_guidance_end` { controlnet_guidance_end } "
563
+ )
546
564
547
565
def get_timesteps (self , num_inference_steps , strength , device ):
548
566
# get the original timestep using init_timestep
@@ -643,6 +661,8 @@ def __call__(
643
661
callback_steps : int = 1 ,
644
662
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
645
663
controlnet_conditioning_scale : float = 1.0 ,
664
+ controlnet_guidance_start : float = 0.0 ,
665
+ controlnet_guidance_end : float = 1.0 ,
646
666
):
647
667
r"""
648
668
Function invoked when calling the pipeline for generation.
@@ -719,6 +739,11 @@ def __call__(
719
739
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
720
740
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
721
741
to the residual in the original unet.
742
+ controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
743
+ The fraction of total steps at which the controlnet starts being applied. Must be between 0 and 1.
744
+ controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
745
+ The percentage of total steps the controlnet ends applying. Must be between 0 and 1. Must be greater
746
+ than or equal to `controlnet_guidance_start`.
722
747
723
748
Examples:
724
749
@@ -745,6 +770,8 @@ def __call__(
745
770
prompt_embeds ,
746
771
negative_prompt_embeds ,
747
772
strength ,
773
+ controlnet_guidance_start ,
774
+ controlnet_guidance_end ,
748
775
)
749
776
750
777
# 2. Define call parameters
@@ -820,19 +847,31 @@ def __call__(
820
847
821
848
latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
822
849
823
- down_block_res_samples , mid_block_res_sample = self .controlnet (
824
- latent_model_input ,
825
- t ,
826
- encoder_hidden_states = prompt_embeds ,
827
- controlnet_cond = controlnet_conditioning_image ,
828
- return_dict = False ,
829
- )
830
-
831
- down_block_res_samples = [
832
- down_block_res_sample * controlnet_conditioning_scale
833
- for down_block_res_sample in down_block_res_samples
834
- ]
835
- mid_block_res_sample *= controlnet_conditioning_scale
850
+ # compute the percentage of total steps we are at
851
+ current_sampling_percent = i / len (timesteps )
852
+
853
+ if (
854
+ current_sampling_percent < controlnet_guidance_start
855
+ or current_sampling_percent > controlnet_guidance_end
856
+ ):
857
+ # do not apply the controlnet
858
+ down_block_res_samples = None
859
+ mid_block_res_sample = None
860
+ else :
861
+ # apply the controlnet
862
+ down_block_res_samples , mid_block_res_sample = self .controlnet (
863
+ latent_model_input ,
864
+ t ,
865
+ encoder_hidden_states = prompt_embeds ,
866
+ controlnet_cond = controlnet_conditioning_image ,
867
+ return_dict = False ,
868
+ )
869
+
870
+ down_block_res_samples = [
871
+ down_block_res_sample * controlnet_conditioning_scale
872
+ for down_block_res_sample in down_block_res_samples
873
+ ]
874
+ mid_block_res_sample *= controlnet_conditioning_scale
836
875
837
876
# predict the noise residual
838
877
noise_pred = self .unet (
0 commit comments