From 5febd94b296d4953cb556d0c24281c25101c20c4 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 2 Dec 2024 11:34:54 +0000 Subject: [PATCH] Add `sigmas` to Flux pipelines --- src/diffusers/pipelines/flux/pipeline_flux.py | 15 +++++++-------- .../pipelines/flux/pipeline_flux_control.py | 15 +++++++-------- .../flux/pipeline_flux_control_img2img.py | 15 +++++++-------- .../pipelines/flux/pipeline_flux_controlnet.py | 15 +++++++-------- .../pipeline_flux_controlnet_image_to_image.py | 13 +++++++------ .../flux/pipeline_flux_controlnet_inpainting.py | 13 +++++++------ .../pipelines/flux/pipeline_flux_fill.py | 15 +++++++-------- .../pipelines/flux/pipeline_flux_img2img.py | 15 +++++++-------- .../pipelines/flux/pipeline_flux_inpaint.py | 15 +++++++-------- 9 files changed, 63 insertions(+), 68 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e0add1e60ce2..ec2801625552 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -554,7 +554,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -585,10 +585,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -699,7 +699,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -712,8 +712,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 04a93ba6351c..dc3ca8cf7b09 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -621,7 +621,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -660,10 +660,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -799,7 +799,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -812,8 +812,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index ef20ab98ee2e..7001b19569f2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -647,7 +647,7 @@ def __call__( width: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -698,10 +698,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -805,7 +805,7 @@ def __call__( ) # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, @@ -818,8 +818,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index ce7ea35c6cea..4c2d2a0a3db9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -602,7 +602,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, @@ -638,10 +638,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -872,7 +872,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -885,8 +885,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 6ab34d8a9c08..4c82d73f0379 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -646,7 +646,7 @@ def __call__( width: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, @@ -685,8 +685,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 28): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). control_mode (`int` or `List[int]`, *optional*): @@ -858,7 +860,7 @@ def __call__( control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) control_mode = control_mode.reshape([-1, 1]) - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, @@ -871,8 +873,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d81cffaca35b..c557cf134b05 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -752,7 +752,7 @@ def __call__( width: Optional[int] = None, strength: float = 0.6, padding_mask_crop: Optional[int] = None, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, num_inference_steps: int = 28, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, @@ -799,8 +799,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 28): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): @@ -1009,7 +1011,7 @@ def __call__( # 6. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( int(global_width) // self.vae_scale_factor // 2 ) @@ -1024,8 +1026,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 32b2bbefa709..723478ce724d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -689,7 +689,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 30.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -735,10 +735,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -878,7 +878,7 @@ def __call__( masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) # 6. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -891,8 +891,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index d34d9b53aa6b..2b336fbdd472 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -593,7 +593,7 @@ def __call__( width: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -636,10 +636,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -742,7 +742,7 @@ def __call__( ) # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, @@ -755,8 +755,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 3fcf6ace8a79..15abdb90ebd0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -693,7 +693,7 @@ def __call__( padding_mask_crop: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 28, - timesteps: List[int] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -753,10 +753,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -873,7 +873,7 @@ def __call__( ) # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, @@ -886,8 +886,7 @@ def __call__( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, mu=mu, ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)