From 28296791e00394cc34b007dfdd853cf8b0c9e044 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 21 Nov 2024 06:53:58 +0100 Subject: [PATCH 01/58] update --- scripts/convert_flux_to_diffusers.py | 6 +- .../models/transformers/transformer_flux.py | 7 +- src/diffusers/pipelines/flux/pipeline_flux.py | 88 +++++++++++++++++-- .../flux/pipeline_flux_controlnet.py | 1 + 4 files changed, 94 insertions(+), 8 deletions(-) diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py index 05a1da256d33..dd9b464c818c 100644 --- a/scripts/convert_flux_to_diffusers.py +++ b/scripts/convert_flux_to_diffusers.py @@ -37,6 +37,8 @@ parser.add_argument("--original_state_dict_repo_id", default=None, type=str) parser.add_argument("--filename", default="flux.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--in_channels", type=int, default=64) +parser.add_argument("--out_channels", type=int, default=None) parser.add_argument("--vae", action="store_true") parser.add_argument("--transformer", action="store_true") parser.add_argument("--output_path", type=str) @@ -282,7 +284,9 @@ def main(args): converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers( original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio ) - transformer = FluxTransformer2DModel(guidance_embeds=has_guidance) + transformer = FluxTransformer2DModel( + in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance + ) transformer.load_state_dict(converted_transformer_state_dict, strict=True) print( diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0ad3be866019..18527e3c46c0 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -238,6 +238,7 @@ def __init__( self, patch_size: int = 1, in_channels: int = 64, + out_channels: Optional[int] = None, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, @@ -248,7 +249,7 @@ def __init__( axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() - self.out_channels = in_channels + self.out_channels = out_channels or in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) @@ -261,7 +262,7 @@ def __init__( ) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ @@ -449,6 +450,7 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
) + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -456,6 +458,7 @@ def forward( guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None + temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 12996f3f3e92..40975c2cc0db 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast -from ...image_processor import VaeImageProcessor +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel @@ -513,7 +513,7 @@ def prepare_latents( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: @@ -529,6 +529,41 @@ def prepare_latents( return latents, latent_image_ids + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + @property def guidance_scale(self): return self._guidance_scale @@ -556,9 +591,11 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 3.5, + control_image: PipelineImageInput = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + control_latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -595,6 +632,14 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. 
If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -667,6 +712,7 @@ def __call__( device = self._execution_device + # 3. Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) @@ -686,7 +732,35 @@ def __call__( ) # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_latents = ( + self.transformer.config.in_channels // 4 + if control_image is None + else self.transformer.config.in_channels // 8 + ) + + if control_image is not None and control_latents is None: + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_latents.shape[2:] + control_latents = self._pack_latents( + control_latents, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -732,11 +806,16 @@ def __call__( if self.interrupt: continue + if control_latents is not None: + latent_model_input = torch.cat([latents, control_latents], dim=2) + else: + latent_model_input = latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, @@ -774,7 +853,6 @@ def __call__( if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 904173852ee4..ac5abdca079a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -736,6 +736,7 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype + # 3. 
Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) From f56ffb1d377d36296cf423473c872fe6a26e7994 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 21 Nov 2024 17:40:17 +0100 Subject: [PATCH 02/58] add --- scripts/convert_flux_to_diffusers.py | 9 +- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_flux.py | 6 +- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/flux/__init__.py | 2 + .../pipelines/flux/pipeline_flux_fill.py | 936 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 7 files changed, 970 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_fill.py diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py index 05a1da256d33..2530a438bd4b 100644 --- a/scripts/convert_flux_to_diffusers.py +++ b/scripts/convert_flux_to_diffusers.py @@ -279,10 +279,17 @@ def main(args): num_single_layers = 38 inner_dim = 3072 mlp_ratio = 4.0 + + # dev has 64, dev-fill has 384 + in_channels = original_ckpt["img_in.weight"].shape[1] + out_channels = 64 + converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers( original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio ) - transformer = FluxTransformer2DModel(guidance_embeds=has_guidance) + transformer = FluxTransformer2DModel( + guidance_embeds=has_guidance, in_channels=in_channels, out_channels=out_channels + ) transformer.load_state_dict(converted_transformer_state_dict, strict=True) print( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d9d7491e5c79..f334a8c50768 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -272,6 +272,7 @@ "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", "FluxControlNetPipeline", + "FluxFillPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", @@ -737,6 +738,7 @@ FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0ad3be866019..cb294fa51304 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -238,6 +238,7 @@ def __init__( self, patch_size: int = 1, in_channels: int = 64, + out_channels: int = None, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, @@ -248,7 +249,10 @@ def __init__( axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() - self.out_channels = in_channels + if out_channels is None: + self.out_channels = in_channels + else: + self.out_channels = out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 98574de1ad5f..4c5bcc824ffa 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -133,6 +133,7 @@ "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", + "FluxFillPipeline", ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -524,6 +525,7 @@ FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + 
FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 0ebf5ea6d78d..a19019aaf986 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -39,6 +40,7 @@ from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_fill import FluxFillPipeline from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py new file mode 100644 index 000000000000..77024625c9c7 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -0,0 +1,936 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. 
+ >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxFillPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=False, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part 
of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + batch_size = batch_size * num_images_per_prompt + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # prepare mask for latents + mask = mask[:, 0, :, :] + mask = mask.view(batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor) + mask = mask.permute(0, 2, 4, 1, 3) + mask = mask.reshape(batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width) + + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + mask = self._pack_latents( + mask, + batch_size, + 64, + height, + width, + ) + mask = mask.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
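+            max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt`.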
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
+ """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: Optional[torch.FloatTensor] = None, + mask_image: Optional[torch.FloatTensor] = None, + masked_image_latents: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.out_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if masked_image_latents is not None: + masked_image_latents = masked_image_latents.to(latents.device) + else: + if image is not None and mask_image is not None: + image = self.image_processor.preprocess(image) + mask_image = self.mask_processor.preprocess(mask_image) + masked_image = image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) + + height, width = image.shape[-2:] + + mask, masked_image_latents = self.prepare_mask_latents( + mask_image, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=torch.cat((latents, masked_image_latents), dim=2), + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8b4b158efd0a..9f8689457555 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -422,6 +422,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxFillPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 7e4df06b7b594dee78610043b81094bb079d443f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 21 Nov 2024 17:45:15 +0100 Subject: [PATCH 03/58] update --- .../pipelines/flux/pipeline_flux_img2img.py | 83 ++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index d34d9b53aa6b..dcabcac6ed7c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -566,6 +566,41 @@ def prepare_latents( latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + 
if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + @property def guidance_scale(self): return self._guidance_scale @@ -595,8 +630,10 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, + control_image: PipelineImageInput = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + control_latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -646,6 +683,14 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -723,6 +768,7 @@ def __call__( device = self._execution_device + # 3. Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) @@ -769,7 +815,34 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_latents = ( + self.transformer.config.in_channels // 4 + if control_image is None + else self.transformer.config.in_channels // 8 + ) + + if control_image is not None and control_latents is None: + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_latents.shape[2:] + control_latents = self._pack_latents( + control_latents, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) latents, latent_image_ids = self.prepare_latents( init_image, @@ -800,10 +873,16 @@ def __call__( if self.interrupt: continue + if control_latents is not None: + latent_model_input = torch.cat([latents, control_latents], dim=2) + else: + latent_model_input = latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, From 217e90cb99e1710267d0858e5d9a12351fd90139 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 21 Nov 2024 23:37:17 +0100 Subject: [PATCH 04/58] add control-lora conversion script; make flux loader handle norms; fix rank calculation assumption --- .../convert_flux_control_lora_to_diffusers.py | 393 ++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 54 ++- src/diffusers/loaders/peft.py | 4 +- 3 files changed, 445 insertions(+), 6 deletions(-) create mode 100644 scripts/convert_flux_control_lora_to_diffusers.py diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py new file mode 100644 index 000000000000..a110bd0bc0c8 --- /dev/null +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -0,0 +1,393 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux-canny-dev-lora.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +# Adapted from from the original BFL codebase. 
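+# `optionally_expand_state_dict` zero-pads a checkpoint tensor so that it can be loaded into a model
+# parameter with a larger shape (e.g. an input projection whose in_features were expanded for control
+# conditioning): the pre-trained values fill the leading slice and the remaining entries stay zero.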
+def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict: + if name in state_dict: + print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.") + # expand with zeros: + expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) + # popular with pre-trained param for the first half. Remaining half stays with zeros. + slices = tuple(slice(0, dim) for dim in state_dict[name].shape) + expanded_state_dict_weight[slices] = state_dict[name] + state_dict[name] = expanded_state_dict_weight + + return state_dict + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_flux_control_lora_checkpoint_to_diffusers( + original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 +): + converted_state_dict = {} + + ## time_text_embed.timestep_embedder <- time_in + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + if f"time_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict[ + f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.weight") + if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.weight") + if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.bias") + + # guidance + has_guidance = any("guidance" in k for k in 
original_state_dict) + if has_guidance: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + + # context_embedder + converted_state_dict[f"context_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"txt_in.{lora_key}.weight" + ) + if f"txt_in.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"context_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"txt_in.{lora_key}.bias" + ) + + # x_embedder + converted_state_dict[f"x_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"img_in.{lora_key}.weight" + ) + if f"img_in.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"x_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"img_in.{lora_key}.bias" + ) + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + # norms + converted_state_dict[f"{block_prefix}norm1.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}norm1.linear.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.img_mod.lin.{lora_key}.bias") + + converted_state_dict[ + f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight") + if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias") + + # Q, K, V + if lora_key == "lora_A": + sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + + context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = 
torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat(
+                    [context_q]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat(
+                    [context_k]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat(
+                    [context_v]
+                )
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict.keys():
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict.keys():
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.bias"] = torch.cat(
+                    [context_q_bias]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.bias"] = torch.cat(
+                    [context_k_bias]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.bias"] = torch.cat(
+                    [context_v_bias]
+                )
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{diffusers_lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[
+                    f"{block_prefix}ff.net.0.proj.{diffusers_lora_key}.bias"
+                ] = original_state_dict.pop(f"double_blocks.{i}.img_mlp.0.{lora_key}.bias")
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[
+                f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.weight"
+            ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight")
+            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[
+                    f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.bias"
+                ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias")
+
+            converted_state_dict[
+                f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.weight"
+            ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight")
+            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[
+                    f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.bias"
+                ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias")
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[
+                    f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.bias"
+                ] = original_state_dict.pop(f"double_blocks.{i}.img_attn.proj.{lora_key}.bias")
+            converted_state_dict[
+                f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.weight"
+            ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight")
+            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict.keys():
+                converted_state_dict[
+                    f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.bias"
+                ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias")
+
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
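+        # single_blocks.{i}.linear1 packs Q, K, V and the MLP input projection into one weight;
+        # the LoRA down-projection (lora_A) is shared across all four, while the up-projection
+        # (lora_B) has to be split with the sizes computed below.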
+ + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.weight" + ) + if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.bias" + ) + + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + + if lora_key == "lora_A": + lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + else: + q, k, v, mlp = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([mlp]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([mlp_bias]) + + # output projections. 
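+            # single_blocks.{i}.linear2 projects the concatenated attention and MLP outputs back to the
+            # hidden size; it maps onto proj_out in the diffusers layout.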
+ converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.weight" + ) + if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.bias" + ) + + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + converted_state_dict[f"proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.weight" + ) + if f"final_layer.linear.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.bias" + ) + + converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.weight"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") + ) + if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.bias"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") + ) + + print("Remaining:", original_state_dict.keys()) + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + + converted_control_lora_state_dict = convert_flux_control_lora_checkpoint_to_diffusers( + original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio + ) + safetensors.torch.save_file(converted_control_lora_state_dict, args.output_path) + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 109592c69c3e..a0f4b582fbe8 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1787,14 +1787,41 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys) + + if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") - transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." 
in k} - if len(transformer_state_dict) > 0: + def prune_state_dict_(state_dict): + pruned_keys = [] + for key in list(state_dict.keys()): + is_lora_key_present = "lora" in key + is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys) + if not is_lora_key_present and not is_norm_key_present: + state_dict.pop(key) + pruned_keys.append(key) + return pruned_keys + + pruned_keys = prune_state_dict_(state_dict) + if len(pruned_keys) > 0: + logger.warning( + f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}" + ) + + transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k} + transformer_norm_state_dict = { + k: v + for k, v in state_dict.items() + if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys) + } + + if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( - state_dict, + transformer_lora_state_dict, network_alphas=network_alphas, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") @@ -1804,6 +1831,14 @@ def load_lora_weights( low_cpu_mem_usage=low_cpu_mem_usage, ) + if len(transformer_norm_state_dict) > 0: + self.load_norm_into_transformer( + transformer_norm_state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( @@ -1860,6 +1895,15 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) + @classmethod + def load_norm_into_transformer( + cls, + state_dict, + transformer: torch.nn.Module, + ): + print(state_dict.keys()) + transformer.load_state_dict(state_dict, strict=True) + @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index bf118c88b2de..fa593c2ecd63 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -216,7 +216,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans rank = {} for key, val in state_dict.items(): - if "lora_B" in key: + # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: rank[key] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: From b4f1cbf28faef9d4f4a1839083eb6032db3888d9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 02:40:05 +0100 Subject: [PATCH 05/58] control lora updates --- src/diffusers/loaders/lora_pipeline.py | 136 +++++++++++++++++++++++-- 1 file changed, 125 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a0f4b582fbe8..1b94dc47b002 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1819,24 +1819,27 @@ def prune_state_dict_(state_dict): if "transformer." 
in k and any(norm_key in k for norm_key in supported_norm_keys) } + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + self._maybe_expand_transformer_param_shape_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) + print(transformer) + if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( transformer_lora_state_dict, network_alphas=network_alphas, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, + transformer=transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) if len(transformer_norm_state_dict) > 0: - self.load_norm_into_transformer( + self._transformer_norm_layers = self.load_norm_into_transformer( transformer_norm_state_dict, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, + transformer=transformer, + discard_original_layers=False, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1899,10 +1902,41 @@ def load_lora_into_transformer( def load_norm_into_transformer( cls, state_dict, - transformer: torch.nn.Module, - ): - print(state_dict.keys()) - transformer.load_state_dict(state_dict, strict=True) + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers[key] = transformer_state_dict[key] + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + logger.info( + "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers." + ) + transformer.load_state_dict(state_dict, strict=False) + + return overwritten_layers @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2139,6 +2173,11 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + if len(self._transformer_norm_layers.keys()) > 0: + logger.info( + "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers." + ) + super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) @@ -2157,8 +2196,83 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
""" + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + transformer.load_state_dict(self._transformer_norm_layers) + super().unfuse_lora(components=components) + @classmethod + def _maybe_expand_transformer_param_shape_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ): + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + + def get_submodule(module, name): + for part in name.split("."): + if len(name) == 0: + break + if not hasattr(module, part): + raise AttributeError(f"Submodule '{part}' not found in '{module}'.") + module = getattr(module, part) + return module + + # Expand transformer parameter shapes if they don't match lora + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if hasattr(module, "bias") else None + bias = module_bias is not None + name_split = name.split(".") + + lora_A_name = f"{name}.lora_A" + lora_B_name = f"{name}.lora_B" + lora_A_weight_name = f"{lora_A_name}.weight" + lora_B_weight_name = f"{lora_B_name}.weight" + + if lora_A_weight_name not in state_dict.keys(): + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + if tuple(module_weight.shape) == (out_features, in_features): + continue + + parent_module_name = ".".join(name_split[:-1]) + current_module_name = name_split[-1] + parent_module = get_submodule(transformer, parent_module_name) + + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype + ) + + new_weight = module_weight.new_zeros(expanded_module.weight.data.shape) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + expanded_module.weight.data.copy_(new_weight) + + if bias: + new_bias = module_bias.new_zeros(expanded_module.bias.data.shape) + slices = tuple(slice(0, dim) for dim in module_bias.shape) + new_bias[slices] = module_bias + expanded_module.bias.data.copy_(new_bias) + + setattr(parent_module, current_module_name, expanded_module) + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
From 414b30b37df93ebb2db482fff70dfa0e6dcf7529 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 02:42:50 +0100 Subject: [PATCH 06/58] remove copied-from --- src/diffusers/loaders/lora_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1b94dc47b002..b34370e620c3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2133,7 +2133,6 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder"], From 6b02ac201a85d15e0ce6fa16e01cd31c0bf50d28 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 06:25:47 +0100 Subject: [PATCH 07/58] create separate pipelines for flux control --- src/diffusers/__init__.py | 4 + src/diffusers/loaders/peft.py | 1 + src/diffusers/pipelines/__init__.py | 4 + src/diffusers/pipelines/flux/__init__.py | 4 + src/diffusers/pipelines/flux/pipeline_flux.py | 86 +- .../pipelines/flux/pipeline_flux_control.py | 869 ++++++++++++++++ .../flux/pipeline_flux_control_img2img.py | 933 ++++++++++++++++++ .../pipelines/flux/pipeline_flux_img2img.py | 83 +- 8 files changed, 1821 insertions(+), 163 deletions(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_control.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f334a8c50768..e6acfc17109f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -269,9 +269,11 @@ "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", "CycleDiffusionPipeline", + "FluxControlImg2ImgPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", "FluxControlNetPipeline", + "FluxControlPipeline", "FluxFillPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", @@ -735,9 +737,11 @@ CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, CycleDiffusionPipeline, + FluxControlImg2ImgPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index fa593c2ecd63..7b842e752699 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -226,6 +226,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + print(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4c5bcc824ffa..4373bda25e21 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -127,6 +127,8 @@ "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["flux"] = [ + "FluxControlPipeline", + "FluxControlImg2ImgPipeline", "FluxControlNetPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", @@ -522,9 +524,11 @@ VQDiffusionPipeline, ) from .flux import ( + FluxControlImg2ImgPipeline, FluxControlNetImg2ImgPipeline, 
FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index a19019aaf986..083e80c7360e 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -23,6 +23,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] + _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] @@ -37,6 +39,8 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_flux import FluxPipeline + from .pipeline_flux_control import FluxControlPipeline + from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 40975c2cc0db..e0add1e60ce2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel @@ -529,41 +529,6 @@ def prepare_latents( return latents, latent_image_ids - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - @property def guidance_scale(self): return self._guidance_scale @@ -591,11 +556,9 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 3.5, - control_image: PipelineImageInput = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - control_latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, 
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -632,14 +595,6 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted - as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or - width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, - images must be passed as a list such that each element of the list can be correctly batched for input - to a single ControlNet. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -712,7 +667,6 @@ def __call__( device = self._execution_device - # 3. Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) @@ -732,35 +686,7 @@ def __call__( ) # 4. Prepare latent variables - num_channels_latents = ( - self.transformer.config.in_channels // 4 - if control_image is None - else self.transformer.config.in_channels // 8 - ) - - if control_image is not None and control_latents is None: - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - - control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) - control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - height_control_image, width_control_image = control_latents.shape[2:] - control_latents = self._pack_latents( - control_latents, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - + num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -806,16 +732,11 @@ def __call__( if self.interrupt: continue - if control_latents is not None: - latent_model_input = torch.cat([latents, control_latents], dim=2) - else: - latent_model_input = latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, @@ -853,6 +774,7 @@ def __call__( if output_type == "latent": image = latents + else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py new file mode 
100644 index 000000000000..05af2e090991 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -0,0 +1,869 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlPipeline + + >>> pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for controllable text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, 
TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
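[Editor's note] As a sanity check on the packing scheme defined above, the sketch below packs a `(B, C, H, W)` latent into the `(B, H*W/4, 4*C)` token layout consumed by the transformer and unpacks it back. It reimplements the two static methods with latent-space height/width arguments for brevity; shapes are illustrative.

```py
import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): every 2x2 spatial patch becomes one token.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

def unpack_latents(latents, height, width):
    # Inverse of pack_latents, taking latent-space height/width directly.
    batch_size, _, channels = latents.shape
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // 4, height, width)

x = torch.randn(1, 16, 128, 128)           # e.g. a 1024px image after 8x VAE compression
tokens = pack_latents(x, 1, 16, 128, 128)  # -> (1, 4096, 64)
assert tokens.shape == (1, 4096, 64)
assert torch.equal(unpack_latents(tokens, 128, 128), x)
```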
+ """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + control_latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: 
Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + if control_latents is None: + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_latents.shape[2:] + control_latents = self._pack_latents( + control_latents, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_latents], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
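[Editor's note] Worth noting for readers skimming this loop: the control conditioning here is not a separate ControlNet branch. The packed control latents are concatenated with the noisy latents along the last (feature) dimension, so the transformer sees twice as many input features per token. That is consistent with `num_channels_latents = self.transformer.config.in_channels // 8` above: assuming a control checkpoint configured with `in_channels=128`, this yields the usual 16 VAE latent channels. A shape-only sketch:

```py
import torch

# Packed streams: 16 VAE latent channels * 4 (2x2 patch packing) = 64 features per token.
batch, num_tokens = 1, 4096
latents = torch.randn(batch, num_tokens, 64)          # noisy image tokens
control_latents = torch.randn(batch, num_tokens, 64)  # packed control-image tokens

latent_model_input = torch.cat([latents, control_latents], dim=2)
print(latent_model_input.shape)  # torch.Size([1, 4096, 128]) -> matches in_channels=128
```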
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py new file mode 100644 index 000000000000..afda0bd9355f --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -0,0 +1,933 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxControlImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 + ... 
) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
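[Editor's note] `calculate_shift` above linearly interpolates the flow-matching shift `mu` between `base_shift` and `max_shift` as the number of image tokens grows, so larger images get a stronger timestep shift. A quick numeric check with the defaults shown in the signature (the scheduler config used at runtime may supply different values):

```py
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# 1024x1024 pixels -> 128x128 latent -> 64x64 packed patches = 4096 tokens
print(round(calculate_shift(4096), 3))  # 1.16 (the max, since 4096 == max_seq_len)
# 512x512 pixels -> 32x32 packed patches = 1024 tokens
print(round(calculate_shift(1024), 3))  # 0.632
```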
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of 
CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % 
(self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
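[Editor's note] The position ids built by `_prepare_latent_image_ids` above are simply a flattened (row, column) grid with a leading zero channel, one id per packed 2x2 latent patch (text tokens get all-zero ids, as in `encode_prompt`). A tiny sketch with a 2x3 grid:

```py
import torch

def prepare_latent_image_ids(height, width):
    # One (0, row, col) id per packed latent patch; channel 0 is left at zero.
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = torch.arange(height)[:, None]
    ids[..., 2] = torch.arange(width)[None, :]
    return ids.reshape(height * width, 3)

print(prepare_latent_image_ids(2, 3))
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [0., 0., 2.],
#         [0., 1., 0.],
#         [0., 1., 1.],
#         [0., 1., 2.]])
```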
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
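[Editor's note] For the img2img path, `strength` controls both how many scheduler steps are actually run (`get_timesteps` above drops the first part of the schedule) and how strongly the encoded image is noised at the first retained sigma. The sketch below reproduces just that bookkeeping; the blend formula in the comment is an assumption about `FlowMatchEulerDiscreteScheduler.scale_noise`, and the dynamic `mu` shift applied at runtime is ignored for simplicity.

```py
import numpy as np

num_inference_steps, strength = 28, 0.6

# get_timesteps: keep only the last `strength` fraction of the schedule.
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 16.8
t_start = int(max(num_inference_steps - init_timestep, 0))                # 11
steps_actually_run = num_inference_steps - t_start                        # 17

# The first retained sigma drives the initial noising of the encoded image
# (assumed blend: x_t = (1 - sigma) * image_latents + sigma * noise).
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
first_sigma = float(sigmas[t_start])
print(steps_actually_run, round(first_sigma, 3))  # 17 0.607
```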
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + control_latents: Optional[torch.FloatTensor] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. 
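[Editor's note] `prepare_image` (copied above) resizes and normalizes the conditioning image and then tiles it to the effective batch size; the only branching is on whether a single image or a per-prompt batch was passed. A minimal sketch of that batching rule with plain tensors:

```py
import torch

def tile_to_batch(image, batch_size, num_images_per_prompt):
    # A single conditioning image is repeated once per requested sample;
    # a per-prompt batch is repeated num_images_per_prompt times instead.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    return image.repeat_interleave(repeat_by, dim=0)

single = torch.randn(1, 3, 1024, 1024)      # one control image for all prompts
per_prompt = torch.randn(2, 3, 1024, 1024)  # one control image per prompt

# batch_size here is already batch_size * num_images_per_prompt, as passed by the pipeline.
print(tile_to_batch(single, batch_size=4, num_images_per_prompt=2).shape)      # torch.Size([4, 3, 1024, 1024])
print(tile_to_batch(per_prompt, batch_size=4, num_images_per_prompt=2).shape)  # torch.Size([4, 3, 1024, 1024])
```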
If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + if control_latents is None: + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_latents.shape[2:] + control_latents = self._pack_latents( + control_latents, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_latents], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index dcabcac6ed7c..d34d9b53aa6b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -566,41 +566,6 @@ def prepare_latents( latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - @property def guidance_scale(self): return self._guidance_scale @@ -630,10 +595,8 @@ def __call__( num_inference_steps: int = 28, 
timesteps: List[int] = None, guidance_scale: float = 7.0, - control_image: PipelineImageInput = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - control_latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -683,14 +646,6 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted - as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or - width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, - images must be passed as a list such that each element of the list can be correctly batched for input - to a single ControlNet. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -768,7 +723,6 @@ def __call__( device = self._execution_device - # 3. Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) @@ -815,34 +769,7 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. 
Prepare latent variables - num_channels_latents = ( - self.transformer.config.in_channels // 4 - if control_image is None - else self.transformer.config.in_channels // 8 - ) - - if control_image is not None and control_latents is None: - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - - control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) - control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - height_control_image, width_control_image = control_latents.shape[2:] - control_latents = self._pack_latents( - control_latents, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( init_image, @@ -873,16 +800,10 @@ def __call__( if self.interrupt: continue - if control_latents is not None: - latent_model_input = torch.cat([latents, control_latents], dim=2) - else: - latent_model_input = latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, From 3169bf5ee11ffdb3ae8fb7d53ba7d55b7afdc0b8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 06:26:28 +0100 Subject: [PATCH 08/58] make fix-copies --- .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9f8689457555..16cbe9bcb354 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -377,6 +377,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -422,6 +437,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxFillPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From f7f006dbd3a712cd7a6df3a0db42b7e5ea421c81 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 07:44:18 +0100 Subject: [PATCH 09/58] update docs --- 
docs/source/en/api/pipelines/flux.md | 87 ++++++++++++++++++- .../pipelines/flux/pipeline_flux_control.py | 34 ++++++-- .../pipelines/flux/pipeline_flux_img2img.py | 42 ++++++--- 3 files changed, 138 insertions(+), 25 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 255c69c854bc..182ce335b57a 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -22,12 +22,19 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca -Flux comes in two variants: +Flux comes in the following variants: -* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) -* Guidance-distilled (`black-forest-labs/FLUX.1-dev`) +| model type | model id | +|:----------:|:--------:| +| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | +| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) | +| Fill Inpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) | +| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) | +| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) | +| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) | +| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) | -Both checkpoints have slightly difference usage which we detail below. +All checkpoints have slightly difference usage which we detail below. ### Timestep-distilled @@ -77,7 +84,67 @@ out = pipe( out.save("image.png") ``` +### Canny Control + +**Note:** `black-forest-labs/Flux.1-Canny-dev` is not a ControlNet model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. + +```python +import torch +from controlnet_aux import CannyDetector +from diffusers import FluxControlPipeline +from diffusers.utils import load_image + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = CannyDetector() +control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024) + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=30.0, +).images[0] +image.save("output.png") +``` + +### Depth Control + +**Note:** `black-forest-labs/Flux.1-Depth-dev` is not a ControlNet model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. 
Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. + +```python +import torch +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils import load_image +from image_gen_aux import DepthPreprocessor + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=30, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + ## Running FP16 inference + Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. FP16 inference code: @@ -188,3 +255,15 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxControlNetImg2ImgPipeline - all - __call__ + +## FluxControlPipeline + +[[autodoc]] FluxControlPipeline + - all + - __call__ + +## FluxControlImg2ImgPipeline + +[[autodoc]] FluxControlImg2ImgPipeline + - all + - __call__ diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 05af2e090991..44051158e3e2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -51,15 +51,33 @@ Examples: ```py >>> import torch + >>> from controlnet_aux import CannyDetector >>> from diffusers import FluxControlPipeline - - >>> pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> # Depending on the variant being used, the pipeline call will slightly vary. - >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... 
control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... control_image=control_image, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, + ... ).images[0] + >>> image.save("output.png") ``` """ diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index d34d9b53aa6b..d894091aa11d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -51,22 +51,38 @@ Examples: ```py >>> import torch - - >>> from diffusers import FluxImg2ImgPipeline + >>> from controlnet_aux import CannyDetector + >>> from diffusers import FluxControlImg2ImgPipeline >>> from diffusers.utils import load_image - >>> device = "cuda" - >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe = pipe.to(device) - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> init_image = load_image(url).resize((1024, 1024)) - - >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" - - >>> images = pipe( - ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... image=image, + ... control_image=control_image, + ... strength=0.8, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, ... 
).images[0] + >>> image.save("output.png") ``` """ From 8bb940ed31c08ff17bfbf466aea611b9ae148365 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 07:57:03 +0100 Subject: [PATCH 10/58] add tests --- .../flux/test_pipeline_flux_control.py | 203 ++++++++++++++++++ .../test_pipeline_flux_control_img2img.py | 168 +++++++++++++++ 2 files changed, 371 insertions(+) create mode 100644 tests/pipelines/flux/test_pipeline_flux_control.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_control_img2img.py diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py new file mode 100644 index 000000000000..2bd511db3d65 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -0,0 +1,203 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import torch_device + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + + +class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + control_image = Image.new("RGB", (16, 16), 0) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": control_image, + "generator": generator, + "num_inference_steps": 
2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." 
+ + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py new file mode 100644 index 000000000000..807013270eda --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py @@ -0,0 +1,168 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlImg2ImgPipeline, + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlImg2ImgPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = Image.new("RGB", (16, 16), 0) + control_image = Image.new("RGB", (16, 16), 0) + + inputs = { + "prompt": "A painting of a squirrel eating 
a burger", + "image": image, + "control_image": control_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) From 9e615fdf426a3a4cd6171d5fc76d9810243a43eb Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 07:59:00 +0100 Subject: [PATCH 11/58] fix --- .../flux/pipeline_flux_control_img2img.py | 34 ++++++++++----- .../pipelines/flux/pipeline_flux_img2img.py | 42 ++++++------------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index afda0bd9355f..684908e93deb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -51,24 +51,38 @@ Examples: ```py >>> import torch - + >>> from controlnet_aux import CannyDetector >>> from diffusers import FluxControlImg2ImgPipeline >>> from diffusers.utils import load_image - >>> device = "cuda" >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 - ... ) - >>> pipe = pipe.to(device) + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> init_image = load_image(url).resize((1024, 1024)) + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. 
Abstract background" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) - >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) - >>> images = pipe( - ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + >>> image = pipe( + ... prompt=prompt, + ... image=image, + ... control_image=control_image, + ... strength=0.8, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, ... ).images[0] + >>> image.save("output.png") ``` """ diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index d894091aa11d..d34d9b53aa6b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -51,38 +51,22 @@ Examples: ```py >>> import torch - >>> from controlnet_aux import CannyDetector - >>> from diffusers import FluxControlImg2ImgPipeline + + >>> from diffusers import FluxImg2ImgPipeline >>> from diffusers.utils import load_image - >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 - ... ).to("cuda") - - >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background" - >>> image = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg" - ... ) - >>> control_image = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" - ... ) - - >>> processor = CannyDetector() - >>> control_image = processor( - ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 - ... ) - - >>> image = pipe( - ... prompt=prompt, - ... image=image, - ... control_image=control_image, - ... strength=0.8, - ... height=1024, - ... width=1024, - ... num_inference_steps=50, - ... guidance_scale=30.0, + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 ... 
).images[0] - >>> image.save("output.png") ``` """ From 89fd9707742bf17ebfec88b336a23f0c27a30b2f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 23 Nov 2024 02:18:47 +0530 Subject: [PATCH 12/58] Apply suggestions from code review Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/flux.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 182ce335b57a..76aea520a99b 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -33,8 +33,9 @@ Flux comes in the following variants: | Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) | | Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) | | Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) | +| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) | -All checkpoints have slightly difference usage which we detail below. +All checkpoints have different usage which we detail below. ### Timestep-distilled @@ -86,7 +87,7 @@ out.save("image.png") ### Canny Control -**Note:** `black-forest-labs/Flux.1-Canny-dev` is not a ControlNet model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. +**Note:** `black-forest-labs/Flux.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. ```python import torch @@ -115,7 +116,7 @@ image.save("output.png") ### Depth Control -**Note:** `black-forest-labs/Flux.1-Depth-dev` is not a ControlNet model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. +**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. 
```python import torch From 73cfc519c9b99b7dc3251cc6a90a5db3056c4819 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 21:50:30 +0100 Subject: [PATCH 13/58] remove control lora changes --- .../convert_flux_control_lora_to_diffusers.py | 393 ------------------ src/diffusers/loaders/lora_pipeline.py | 175 +------- src/diffusers/loaders/peft.py | 5 +- 3 files changed, 10 insertions(+), 563 deletions(-) delete mode 100644 scripts/convert_flux_control_lora_to_diffusers.py diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py deleted file mode 100644 index a110bd0bc0c8..000000000000 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ /dev/null @@ -1,393 +0,0 @@ -import argparse -from contextlib import nullcontext - -import safetensors.torch -import torch -from accelerate import init_empty_weights -from huggingface_hub import hf_hub_download - -from diffusers.utils.import_utils import is_accelerate_available - - -CTX = init_empty_weights if is_accelerate_available else nullcontext - -parser = argparse.ArgumentParser() -parser.add_argument("--original_state_dict_repo_id", default=None, type=str) -parser.add_argument("--filename", default="flux-canny-dev-lora.safetensors", type=str) -parser.add_argument("--checkpoint_path", default=None, type=str) -parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="bf16") - -args = parser.parse_args() -dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 - - -# Adapted from from the original BFL codebase. -def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict: - if name in state_dict: - print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.") - # expand with zeros: - expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) - # popular with pre-trained param for the first half. Remaining half stays with zeros. - slices = tuple(slice(0, dim) for dim in state_dict[name].shape) - expanded_state_dict_weight[slices] = state_dict[name] - state_dict[name] = expanded_state_dict_weight - - return state_dict - - -def load_original_checkpoint(args): - if args.original_state_dict_repo_id is not None: - ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) - elif args.checkpoint_path is not None: - ckpt_path = args.checkpoint_path - else: - raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") - - original_state_dict = safetensors.torch.load_file(ckpt_path) - return original_state_dict - - -# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; -# while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use diffusers implementation -def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - -def convert_flux_control_lora_checkpoint_to_diffusers( - original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 -): - converted_state_dict = {} - - ## time_text_embed.timestep_embedder <- time_in - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") - if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") - - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") - if f"time_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") - - ## time_text_embed.text_embedder <- vector_in - converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.weight") - if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.bias") - - converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.weight") - if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.bias") - - # guidance - has_guidance = any("guidance" in k for k in original_state_dict) - if has_guidance: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") - if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") - - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") - if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") - - # context_embedder - converted_state_dict[f"context_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"txt_in.{lora_key}.weight" - ) - if f"txt_in.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"context_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"txt_in.{lora_key}.bias" - ) - - # x_embedder - 
converted_state_dict[f"x_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"img_in.{lora_key}.weight" - ) - if f"img_in.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"x_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"img_in.{lora_key}.bias" - ) - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - # norms - converted_state_dict[f"{block_prefix}norm1.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}norm1.linear.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_mod.lin.{lora_key}.bias") - - converted_state_dict[ - f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight") - if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias") - - # Q, K, V - if lora_key == "lora_A": - sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) - - context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - else: - sample_q, sample_k, sample_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([sample_v]) - - context_q, context_k, context_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_q] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_k] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat( - [context_v] - ) - - if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 - ) - 
converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([sample_v_bias]) - - if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.bias"] = torch.cat( - [context_q_bias] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.bias"] = torch.cat( - [context_k_bias] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.bias"] = torch.cat( - [context_v_bias] - ) - - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff.net.0.proj{diffusers_lora_key}..bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_mlp.0.{lora_key}.bias") - - converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" - ) - - converted_state_dict[ - f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight") - if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias") - - converted_state_dict[ - f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight") - if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias") - - # output projections. 
- converted_state_dict[f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_attn.proj.{lora_key}.bias") - converted_state_dict[ - f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight") - if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias") - - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - - # single transfomer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." - - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"single_blocks.{i}.modulation.lin.{lora_key}.weight" - ) - if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"single_blocks.{i}.modulation.lin.{lora_key}.bias" - ) - - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - - if lora_key == "lora_A": - lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): - lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - else: - q, k, v, mlp = torch.split( - original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 - ) - 
converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([mlp]) - - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): - q_bias, k_bias, v_bias, mlp_bias = torch.split( - original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([mlp_bias]) - - # output projections. - converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"single_blocks.{i}.linear2.{lora_key}.weight" - ) - if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"single_blocks.{i}.linear2.{lora_key}.bias" - ) - - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - converted_state_dict[f"proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( - f"final_layer.linear.{lora_key}.weight" - ) - if f"final_layer.linear.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( - f"final_layer.linear.{lora_key}.bias" - ) - - converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.weight"] = swap_scale_shift( - original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") - ) - if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.bias"] = swap_scale_shift( - original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") - ) - - print("Remaining:", original_state_dict.keys()) - - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict - - -def main(args): - original_ckpt = load_original_checkpoint(args) - - num_layers = 19 - num_single_layers = 38 - inner_dim = 3072 - mlp_ratio = 4.0 - - converted_control_lora_state_dict = convert_flux_control_lora_checkpoint_to_diffusers( - original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio - ) - safetensors.torch.save_file(converted_control_lora_state_dict, args.output_path) - - -if __name__ == "__main__": - main(args) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b34370e620c3..109592c69c3e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1787,61 +1787,23 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, 
return_alphas=True, **kwargs ) - has_lora_keys = any("lora" in key for key in state_dict.keys()) - - # Flux Control LoRAs also have norm keys - supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] - has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys) - - if not (has_lora_keys or has_norm_keys): + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - def prune_state_dict_(state_dict): - pruned_keys = [] - for key in list(state_dict.keys()): - is_lora_key_present = "lora" in key - is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys) - if not is_lora_key_present and not is_norm_key_present: - state_dict.pop(key) - pruned_keys.append(key) - return pruned_keys - - pruned_keys = prune_state_dict_(state_dict) - if len(pruned_keys) > 0: - logger.warning( - f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}" - ) - - transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k} - transformer_norm_state_dict = { - k: v - for k, v in state_dict.items() - if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys) - } - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - self._maybe_expand_transformer_param_shape_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) - print(transformer) - - if len(transformer_lora_state_dict) > 0: + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} + if len(transformer_state_dict) > 0: self.load_lora_into_transformer( - transformer_lora_state_dict, + state_dict, network_alphas=network_alphas, - transformer=transformer, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) - if len(transformer_norm_state_dict) > 0: - self._transformer_norm_layers = self.load_norm_into_transformer( - transformer_norm_state_dict, - transformer=transformer, - discard_original_layers=False, - ) - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( @@ -1898,46 +1860,6 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) - @classmethod - def load_norm_into_transformer( - cls, - state_dict, - transformer, - prefix=None, - discard_original_layers=False, - ) -> Dict[str, torch.Tensor]: - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) - - # Find invalid keys - transformer_state_dict = transformer.state_dict() - transformer_keys = set(transformer_state_dict.keys()) - state_dict_keys = set(state_dict.keys()) - extra_keys = list(state_dict_keys - transformer_keys) - logger.warning( - f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." 
- ) - - for key in extra_keys: - state_dict.pop(key) - - # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers = {} - if not discard_original_layers: - for key in state_dict.keys(): - overwritten_layers[key] = transformer_state_dict[key] - - # We can't load with strict=True because the current state_dict does not contain all the transformer keys - logger.info( - "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers." - ) - transformer.load_state_dict(state_dict, strict=False) - - return overwritten_layers - @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( @@ -2133,6 +2055,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder"], @@ -2172,11 +2095,6 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ - if len(self._transformer_norm_layers.keys()) > 0: - logger.info( - "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers." - ) - super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) @@ -2195,83 +2113,8 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - transformer.load_state_dict(self._transformer_norm_layers) - super().unfuse_lora(components=components) - @classmethod - def _maybe_expand_transformer_param_shape_( - cls, - transformer: torch.nn.Module, - lora_state_dict=None, - norm_state_dict=None, - prefix=None, - ): - state_dict = {} - if lora_state_dict is not None: - state_dict.update(lora_state_dict) - if norm_state_dict is not None: - state_dict.update(norm_state_dict) - - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) - - def get_submodule(module, name): - for part in name.split("."): - if len(name) == 0: - break - if not hasattr(module, part): - raise AttributeError(f"Submodule '{part}' not found in '{module}'.") - module = getattr(module, part) - return module - - # Expand transformer parameter shapes if they don't match lora - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear): - module_weight = module.weight.data - module_bias = module.bias.data if hasattr(module, "bias") else None - bias = module_bias is not None - name_split = name.split(".") - - lora_A_name = f"{name}.lora_A" - lora_B_name = f"{name}.lora_B" - lora_A_weight_name = f"{lora_A_name}.weight" - lora_B_weight_name = f"{lora_B_name}.weight" - - if lora_A_weight_name not in state_dict.keys(): - continue - - in_features = state_dict[lora_A_weight_name].shape[1] - out_features = state_dict[lora_B_weight_name].shape[0] - - if tuple(module_weight.shape) == (out_features, in_features): - continue - - parent_module_name = ".".join(name_split[:-1]) - 
current_module_name = name_split[-1] - parent_module = get_submodule(transformer, parent_module_name) - - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype - ) - - new_weight = module_weight.new_zeros(expanded_module.weight.data.shape) - slices = tuple(slice(0, dim) for dim in module_weight.shape) - new_weight[slices] = module_weight - expanded_module.weight.data.copy_(new_weight) - - if bias: - new_bias = module_bias.new_zeros(expanded_module.bias.data.shape) - slices = tuple(slice(0, dim) for dim in module_bias.shape) - new_bias[slices] = module_bias - expanded_module.bias.data.copy_(new_bias) - - setattr(parent_module, current_module_name, expanded_module) - # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 7b842e752699..bf118c88b2de 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -216,9 +216,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans rank = {} for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. - # Bias layers in LoRA only have a single dimension - if "lora_B" in key and val.ndim > 1: + if "lora_B" in key: rank[key] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: @@ -226,7 +224,6 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - print(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): From c94966f99a5563f1a848d80fd5b16e211b37084d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 22:16:57 +0100 Subject: [PATCH 14/58] apply suggestions from review --- .../pipelines/flux/pipeline_flux_control.py | 40 ++++++++++--------- .../flux/pipeline_flux_control_img2img.py | 33 ++++++++------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 44051158e3e2..04a93ba6351c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -215,9 +215,14 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) + self.vae_latent_channels = ( + self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 + ) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae_latent_channels + ) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -621,7 +626,6 @@ def __call__( num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - control_latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -760,23 +764,23 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 8 - if control_latents is None: - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) - control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) - control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - height_control_image, width_control_image = control_latents.shape[2:] - control_latents = self._pack_latents( - control_latents, + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, batch_size * num_images_per_prompt, num_channels_latents, height_control_image, @@ -828,7 +832,7 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents, control_latents], dim=2) + latent_model_input = torch.cat([latents, control_image], dim=2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 684908e93deb..ef20ab98ee2e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -651,7 +651,6 @@ def __call__( guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - control_latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -835,23 +834,23 @@ def __call__( # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 8 - if control_latents is None: - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) - control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator) - control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - height_control_image, width_control_image = control_latents.shape[2:] - control_latents = self._pack_latents( - control_latents, + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, batch_size * num_images_per_prompt, num_channels_latents, height_control_image, @@ -887,7 +886,7 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents, control_latents], dim=2) + latent_model_input = torch.cat([latents, control_image], dim=2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) From cfe13e793b7328ced9cda31eb71303554b370a43 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 22 Nov 2024 21:50:30 +0100 Subject: [PATCH 15/58] Revert "remove control lora changes" This reverts commit 73cfc519c9b99b7dc3251cc6a90a5db3056c4819. --- .../convert_flux_control_lora_to_diffusers.py | 393 ++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 175 +++++++- src/diffusers/loaders/peft.py | 5 +- 3 files changed, 563 insertions(+), 10 deletions(-) create mode 100644 scripts/convert_flux_control_lora_to_diffusers.py diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py new file mode 100644 index 000000000000..a110bd0bc0c8 --- /dev/null +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -0,0 +1,393 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux-canny-dev-lora.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +# Adapted from from the original BFL codebase. 
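+# optionally_expand_state_dict zero-pads a checkpoint tensor to the (larger) shape of the matching model parameter:
+# the original values fill the leading slice of each dimension and the remaining entries stay zero.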
+def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict: + if name in state_dict: + print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.") + # expand with zeros: + expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) + # popular with pre-trained param for the first half. Remaining half stays with zeros. + slices = tuple(slice(0, dim) for dim in state_dict[name].shape) + expanded_state_dict_weight[slices] = state_dict[name] + state_dict[name] = expanded_state_dict_weight + + return state_dict + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_flux_control_lora_checkpoint_to_diffusers( + original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 +): + converted_state_dict = {} + + ## time_text_embed.timestep_embedder <- time_in + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + if f"time_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict[ + f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.weight") + if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.weight") + if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.bias") + + # guidance + has_guidance = any("guidance" in k for k in 
original_state_dict) + if has_guidance: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + + # context_embedder + converted_state_dict[f"context_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"txt_in.{lora_key}.weight" + ) + if f"txt_in.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"context_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"txt_in.{lora_key}.bias" + ) + + # x_embedder + converted_state_dict[f"x_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"img_in.{lora_key}.weight" + ) + if f"img_in.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"x_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"img_in.{lora_key}.bias" + ) + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + # norms + converted_state_dict[f"{block_prefix}norm1.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}norm1.linear.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.img_mod.lin.{lora_key}.bias") + + converted_state_dict[ + f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight") + if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias") + + # Q, K, V + if lora_key == "lora_A": + sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat( + [sample_lora_weight] + ) + + context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = 
torch.cat( + [context_lora_weight] + ) + else: + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([sample_v]) + + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_q] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_k] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat( + [context_v] + ) + + if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([sample_v_bias]) + + if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.bias"] = torch.cat( + [context_q_bias] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.bias"] = torch.cat( + [context_k_bias] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.bias"] = torch.cat( + [context_v_bias] + ) + + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}ff.net.0.proj{diffusers_lora_key}..bias" + ] = original_state_dict.pop(f"double_blocks.{i}.img_mlp.0.{lora_key}.bias") + + converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" + ) + + converted_state_dict[ + f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight") + if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias") + + converted_state_dict[ + f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.weight" + ] = 
original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight") + if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias") + + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.img_attn.proj.{lora_key}.bias") + converted_state_dict[ + f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.weight" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight") + if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[ + f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.bias" + ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias") + + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." 
+ + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.weight" + ) + if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.bias" + ) + + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + + if lora_key == "lora_A": + lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + else: + q, k, v, mlp = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([mlp]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([mlp_bias]) + + # output projections. 
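+        # single_blocks.{i}.linear2 maps 1:1 onto proj_out in diffusers, so its LoRA weight (and bias, if present)
+        # is carried over without any splitting.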
+ converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.weight" + ) + if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.bias" + ) + + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + + for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + converted_state_dict[f"proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.weight" + ) + if f"final_layer.linear.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.bias" + ) + + converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.weight"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") + ) + if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict.keys(): + converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.bias"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") + ) + + print("Remaining:", original_state_dict.keys()) + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + + converted_control_lora_state_dict = convert_flux_control_lora_checkpoint_to_diffusers( + original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio + ) + safetensors.torch.save_file(converted_control_lora_state_dict, args.output_path) + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 109592c69c3e..b34370e620c3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1787,23 +1787,61 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys) + + if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") - transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." 
in k} - if len(transformer_state_dict) > 0: + def prune_state_dict_(state_dict): + pruned_keys = [] + for key in list(state_dict.keys()): + is_lora_key_present = "lora" in key + is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys) + if not is_lora_key_present and not is_norm_key_present: + state_dict.pop(key) + pruned_keys.append(key) + return pruned_keys + + pruned_keys = prune_state_dict_(state_dict) + if len(pruned_keys) > 0: + logger.warning( + f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}" + ) + + transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k} + transformer_norm_state_dict = { + k: v + for k, v in state_dict.items() + if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys) + } + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + self._maybe_expand_transformer_param_shape_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) + print(transformer) + + if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( - state_dict, + transformer_lora_state_dict, network_alphas=network_alphas, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, + transformer=transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) + if len(transformer_norm_state_dict) > 0: + self._transformer_norm_layers = self.load_norm_into_transformer( + transformer_norm_state_dict, + transformer=transformer, + discard_original_layers=False, + ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( @@ -1860,6 +1898,46 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) + @classmethod + def load_norm_into_transformer( + cls, + state_dict, + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers[key] = transformer_state_dict[key] + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + logger.info( + "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers." 
+ ) + transformer.load_state_dict(state_dict, strict=False) + + return overwritten_layers + @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( @@ -2055,7 +2133,6 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder"], @@ -2095,6 +2172,11 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + if len(self._transformer_norm_layers.keys()) > 0: + logger.info( + "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers." + ) + super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) @@ -2113,8 +2195,83 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + transformer.load_state_dict(self._transformer_norm_layers) + super().unfuse_lora(components=components) + @classmethod + def _maybe_expand_transformer_param_shape_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ): + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + + def get_submodule(module, name): + for part in name.split("."): + if len(name) == 0: + break + if not hasattr(module, part): + raise AttributeError(f"Submodule '{part}' not found in '{module}'.") + module = getattr(module, part) + return module + + # Expand transformer parameter shapes if they don't match lora + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if hasattr(module, "bias") else None + bias = module_bias is not None + name_split = name.split(".") + + lora_A_name = f"{name}.lora_A" + lora_B_name = f"{name}.lora_B" + lora_A_weight_name = f"{lora_A_name}.weight" + lora_B_weight_name = f"{lora_B_name}.weight" + + if lora_A_weight_name not in state_dict.keys(): + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + if tuple(module_weight.shape) == (out_features, in_features): + continue + + parent_module_name = ".".join(name_split[:-1]) + current_module_name = name_split[-1] + parent_module = get_submodule(transformer, parent_module_name) + + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype + ) + + new_weight = module_weight.new_zeros(expanded_module.weight.data.shape) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + expanded_module.weight.data.copy_(new_weight) + + if bias: + new_bias = module_bias.new_zeros(expanded_module.bias.data.shape) + slices = tuple(slice(0, dim) for dim in module_bias.shape) + new_bias[slices] = 
module_bias + expanded_module.bias.data.copy_(new_bias) + + setattr(parent_module, current_module_name, expanded_module) + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index bf118c88b2de..7b842e752699 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -216,7 +216,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans rank = {} for key, val in state_dict.items(): - if "lora_B" in key: + # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: rank[key] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: @@ -224,6 +226,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + print(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): From 0c959a7e6ffd102d40ebbe0bf0f1849b01c132e3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 23 Nov 2024 05:10:41 +0100 Subject: [PATCH 16/58] update --- .../convert_flux_control_lora_to_diffusers.py | 158 +++++++++--------- 1 file changed, 79 insertions(+), 79 deletions(-) diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py index a110bd0bc0c8..4577fe702e6b 100644 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -61,75 +61,75 @@ def convert_flux_control_lora_checkpoint_to_diffusers( ): converted_state_dict = {} - ## time_text_embed.timestep_embedder <- time_in - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + for lora_key in ["lora_A", "lora_B"]: + ## time_text_embed.timestep_embedder <- time_in converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.weight" + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{diffusers_lora_key}.bias" + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.weight" + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") if f"time_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{diffusers_lora_key}.bias" + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") ## time_text_embed.text_embedder <- vector_in converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.weight" + f"time_text_embed.text_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.weight") if 
f"vector_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{diffusers_lora_key}.bias" + f"time_text_embed.text_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.bias") converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.weight" + f"time_text_embed.text_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.weight") if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{diffusers_lora_key}.bias" + f"time_text_embed.text_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.bias") # guidance has_guidance = any("guidance" in k for k in original_state_dict) if has_guidance: converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.weight" + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{diffusers_lora_key}.bias" + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.weight" + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{diffusers_lora_key}.bias" + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") # context_embedder - converted_state_dict[f"context_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( f"txt_in.{lora_key}.weight" ) if f"txt_in.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"context_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( f"txt_in.{lora_key}.bias" ) # x_embedder - converted_state_dict[f"x_embedder.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop( f"img_in.{lora_key}.weight" ) if f"img_in.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"x_embedder.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop( f"img_in.{lora_key}.bias" ) @@ -137,65 +137,65 @@ def convert_flux_control_lora_checkpoint_to_diffusers( for i in range(num_layers): block_prefix = f"transformer_blocks.{i}." 
- for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): # norms - converted_state_dict[f"{block_prefix}norm1.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" ) if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}norm1.linear.{diffusers_lora_key}.bias" + f"{block_prefix}norm1.linear.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.img_mod.lin.{lora_key}.bias") converted_state_dict[ - f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.weight" + f"{block_prefix}norm1_context.linear.{lora_key}.weight" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight") if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}norm1_context.linear.{diffusers_lora_key}.bias" + f"{block_prefix}norm1_context.linear.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias") # Q, K, V if lora_key == "lora_A": sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat( [sample_lora_weight] ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat( [sample_lora_weight] ) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat( [sample_lora_weight] ) context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) else: sample_q, sample_k, sample_v = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) context_q, context_k, context_v = torch.chunk( 
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( [context_q] ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( [context_k] ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.weight"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( [context_v] ) @@ -203,71 +203,71 @@ def convert_flux_control_lora_checkpoint_to_diffusers( sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): context_q_bias, context_k_bias, context_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{diffusers_lora_key}.bias"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat( [context_q_bias] ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{diffusers_lora_key}.bias"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat( [context_k_bias] ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{diffusers_lora_key}.bias"] = torch.cat( + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat( [context_v_bias] ) # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" ) if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}ff.net.0.proj{diffusers_lora_key}..bias" + f"{block_prefix}ff.net.0.proj{lora_key}..bias" ] = original_state_dict.pop(f"double_blocks.{i}.img_mlp.0.{lora_key}.bias") - converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" ) if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}ff.net.2.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" ) converted_state_dict[ - 
f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.weight" + f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight") if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}ff_context.net.0.proj.{diffusers_lora_key}.bias" + f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias") converted_state_dict[ - f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.weight" + f"{block_prefix}ff_context.net.2.{lora_key}.weight" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight") if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}ff_context.net.2.{diffusers_lora_key}.bias" + f"{block_prefix}ff_context.net.2.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias") # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" ) if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}attn.to_out.0.{diffusers_lora_key}.bias" + f"{block_prefix}attn.to_out.0.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.img_attn.proj.{lora_key}.bias") converted_state_dict[ - f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.weight" + f"{block_prefix}attn.to_add_out.{lora_key}.weight" ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight") if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict.keys(): converted_state_dict[ - f"{block_prefix}attn.to_add_out.{diffusers_lora_key}.bias" + f"{block_prefix}attn.to_add_out.{lora_key}.bias" ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias") # qk_norm @@ -288,13 +288,13 @@ def convert_flux_control_lora_checkpoint_to_diffusers( for i in range(num_single_layers): block_prefix = f"single_transformer_blocks.{i}." 
- for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + for lora_key in ["lora_A", "lora_B"]: # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.weight" ) if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}norm.linear.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.bias" ) @@ -304,41 +304,41 @@ def convert_flux_control_lora_checkpoint_to_diffusers( if lora_key == "lora_A": lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) else: q, k, v, mlp = torch.split( original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = 
torch.cat([mlp]) if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): q_bias, k_bias, v_bias, mlp_bias = torch.split( original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 ) - converted_state_dict[f"{block_prefix}attn.to_q.{diffusers_lora_key}.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{diffusers_lora_key}.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{diffusers_lora_key}.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{diffusers_lora_key}.bias"] = torch.cat([mlp_bias]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) # output projections. - converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.weight" ) if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"{block_prefix}proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.bias" ) @@ -350,20 +350,20 @@ def convert_flux_control_lora_checkpoint_to_diffusers( f"single_blocks.{i}.norm.key_norm.scale" ) - for lora_key, diffusers_lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - converted_state_dict[f"proj_out.{diffusers_lora_key}.weight"] = original_state_dict.pop( + for lora_key in ["lora_A", "lora_B"]: + converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.weight" ) if f"final_layer.linear.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"proj_out.{diffusers_lora_key}.bias"] = original_state_dict.pop( + converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.bias" ) - converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.weight"] = swap_scale_shift( + converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") ) if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"norm_out.linear.{diffusers_lora_key}.bias"] = swap_scale_shift( + converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") ) From 6ef2c8b86e62e307c976a4aa9c846e9c10f5eef3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 23 Nov 2024 05:19:24 +0100 Subject: [PATCH 17/58] update --- .../convert_flux_control_lora_to_diffusers.py | 183 ++++++++---------- src/diffusers/loaders/lora_pipeline.py | 29 +-- 2 files changed, 98 insertions(+), 114 deletions(-) diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py index 4577fe702e6b..cd1295d9765d 100644 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -60,13 +60,14 @@ def 
convert_flux_control_lora_checkpoint_to_diffusers( original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 ): converted_state_dict = {} + original_state_dict_keys = original_state_dict.keys() for lora_key in ["lora_A", "lora_B"]: ## time_text_embed.timestep_embedder <- time_in converted_state_dict[ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") - if f"time_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") @@ -74,27 +75,27 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") - if f"time_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") ## time_text_embed.text_embedder <- vector_in - converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{lora_key}.weight" - ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.weight") - if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.text_embedder.linear_1.{lora_key}.bias" - ] = original_state_dict.pop(f"vector_in.in_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.weight" + ) + if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.bias" + ) - converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{lora_key}.weight" - ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.weight") - if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"time_text_embed.text_embedder.linear_2.{lora_key}.bias" - ] = original_state_dict.pop(f"vector_in.out_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.weight" + ) + if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.bias" + ) # guidance has_guidance = any("guidance" in k for k in original_state_dict) @@ -102,7 +103,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") - if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict.keys(): + if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") @@ -110,7 +111,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( 
converted_state_dict[ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") - if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict.keys(): + if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") @@ -119,19 +120,15 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( f"txt_in.{lora_key}.weight" ) - if f"txt_in.{lora_key}.bias" in original_state_dict.keys(): + if f"txt_in.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( f"txt_in.{lora_key}.bias" ) # x_embedder - converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop( - f"img_in.{lora_key}.weight" - ) - if f"img_in.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop( - f"img_in.{lora_key}.bias" - ) + converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight") + if f"img_in.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias") # double transformer blocks for i in range(num_layers): @@ -142,31 +139,25 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" ) - if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}norm1.linear.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_mod.lin.{lora_key}.bias") + if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" + ) - converted_state_dict[ - f"{block_prefix}norm1_context.linear.{lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight") - if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}norm1_context.linear.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias") + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" + ) # Q, K, V if lora_key == "lora_A": sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat( - [sample_lora_weight] - ) + 
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( @@ -189,17 +180,11 @@ def convert_flux_control_lora_checkpoint_to_diffusers( context_q, context_k, context_v = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( - [context_q] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( - [context_k] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( - [context_v] - ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) - if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): + if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 ) @@ -207,68 +192,62 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) - if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict.keys(): + if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: context_q_bias, context_k_bias, context_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat( - [context_q_bias] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat( - [context_k_bias] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat( - [context_v_bias] - ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) # ff img_mlp converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" ) - if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff.net.0.proj{lora_key}..bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_mlp.0.{lora_key}.bias") + if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff.net.0.proj{lora_key}..bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" + ) converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = 
original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" ) - if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict.keys(): + if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" ) - converted_state_dict[ - f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight") - if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" + ) - converted_state_dict[ - f"{block_prefix}ff_context.net.2.{lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight") - if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}ff_context.net.2.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias") + converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" + ) # output projections. 
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" ) - if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}attn.to_out.0.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.img_attn.proj.{lora_key}.bias") - converted_state_dict[ - f"{block_prefix}attn.to_add_out.{lora_key}.weight" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight") - if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict.keys(): - converted_state_dict[ - f"{block_prefix}attn.to_add_out.{lora_key}.bias" - ] = original_state_dict.pop(f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias") + if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" + ) # qk_norm converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( @@ -293,7 +272,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.weight" ) - if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict.keys(): + if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.bias" ) @@ -309,7 +288,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) @@ -324,7 +303,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict.keys(): + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: q_bias, k_bias, v_bias, mlp_bias = torch.split( original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 ) @@ -337,7 +316,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.weight" ) - if 
f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict.keys(): + if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.bias" ) @@ -354,7 +333,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.weight" ) - if f"final_layer.linear.{lora_key}.bias" in original_state_dict.keys(): + if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.bias" ) @@ -362,7 +341,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") ) - if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict.keys(): + if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b34370e620c3..beffe166b78b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1812,10 +1812,12 @@ def prune_state_dict_(state_dict): f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}" ) - transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k} + transformer_lora_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k + } transformer_norm_state_dict = { - k: v - for k, v in state_dict.items() + k: state_dict.pop(k) + for k in list(state_dict.keys()) if "transformer." 
in k and any(norm_key in k for norm_key in supported_norm_keys) } @@ -1823,7 +1825,6 @@ def prune_state_dict_(state_dict): self._maybe_expand_transformer_param_shape_( transformer, transformer_lora_state_dict, transformer_norm_state_dict ) - print(transformer) if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( @@ -1836,7 +1837,7 @@ def prune_state_dict_(state_dict): ) if len(transformer_norm_state_dict) > 0: - self._transformer_norm_layers = self.load_norm_into_transformer( + self._transformer_norm_layers = self._load_norm_into_transformer( transformer_norm_state_dict, transformer=transformer, discard_original_layers=False, @@ -1899,7 +1900,7 @@ def load_lora_into_transformer( ) @classmethod - def load_norm_into_transformer( + def _load_norm_into_transformer( cls, state_dict, transformer, @@ -1925,10 +1926,10 @@ def load_norm_into_transformer( state_dict.pop(key) # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers = {} + overwritten_layers_state_dict = {} if not discard_original_layers: for key in state_dict.keys(): - overwritten_layers[key] = transformer_state_dict[key] + overwritten_layers_state_dict[key] = transformer_state_dict[key] # We can't load with strict=True because the current state_dict does not contain all the transformer keys logger.info( @@ -1936,7 +1937,7 @@ def load_norm_into_transformer( ) transformer.load_state_dict(state_dict, strict=False) - return overwritten_layers + return overwritten_layers_state_dict @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2196,7 +2197,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
""" transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - transformer.load_state_dict(self._transformer_norm_layers) + transformer.load_state_dict(self._transformer_norm_layers, strict=False) super().unfuse_lora(components=components) @@ -2259,13 +2260,17 @@ def get_submodule(module, name): in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype ) - new_weight = module_weight.new_zeros(expanded_module.weight.data.shape) + new_weight = torch.zeros_like( + expanded_module.weight.data.shape, device=module_weight.device, dtype=module_weight.dtype + ) slices = tuple(slice(0, dim) for dim in module_weight.shape) new_weight[slices] = module_weight expanded_module.weight.data.copy_(new_weight) if bias: - new_bias = module_bias.new_zeros(expanded_module.bias.data.shape) + new_bias = torch.zeros_like( + expanded_module.bias.data.shape, device=module_bias.device, dtype=module_bias.dtype + ) slices = tuple(slice(0, dim) for dim in module_bias.shape) new_bias[slices] = module_bias expanded_module.bias.data.copy_(new_bias) From 42970ee1529dee0592eae1771eab84e62c661fd7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 23 Nov 2024 05:46:26 +0100 Subject: [PATCH 18/58] improve log messages --- src/diffusers/loaders/lora_pipeline.py | 45 ++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index beffe166b78b..7ff20e29f1b1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1822,10 +1822,17 @@ def prune_state_dict_(state_dict): } transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - self._maybe_expand_transformer_param_shape_( + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( transformer, transformer_lora_state_dict, transformer_norm_state_dict ) + if has_param_with_expanded_shape: + logger.info( + "The LoRA weights contain parameters that have different shapes that expected by the transformer. " + "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " + "To get a comprehensive list of parameter names that were modified, enable debug logging." + ) + if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( transformer_lora_state_dict, @@ -1931,10 +1938,13 @@ def _load_norm_into_transformer( for key in state_dict.keys(): overwritten_layers_state_dict[key] = transformer_state_dict[key] - # We can't load with strict=True because the current state_dict does not contain all the transformer keys logger.info( - "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers." + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." 
) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys transformer.load_state_dict(state_dict, strict=False) return overwritten_layers_state_dict @@ -2175,7 +2185,9 @@ def fuse_lora( """ if len(self._transformer_norm_layers.keys()) > 0: logger.info( - "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers." + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." ) super().fuse_lora( @@ -2202,13 +2214,13 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) @classmethod - def _maybe_expand_transformer_param_shape_( + def _maybe_expand_transformer_param_shape_or_error_( cls, transformer: torch.nn.Module, lora_state_dict=None, norm_state_dict=None, prefix=None, - ): + ) -> bool: state_dict = {} if lora_state_dict is not None: state_dict.update(lora_state_dict) @@ -2231,6 +2243,8 @@ def get_submodule(module, name): return module # Expand transformer parameter shapes if they don't match lora + has_param_with_shape_update = False + for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data @@ -2252,6 +2266,23 @@ def get_submodule(module, name): if tuple(module_weight.shape) == (out_features, in_features): continue + module_out_features, module_in_features = module_weight.shape + if out_features < module_out_features or in_features < module_in_features: + raise NotImplementedError( + f"Only LoRAs with input/output features higher than the current modules' input/output features " + f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " + f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " + f"this please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + logger.debug( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}, and the number of output features will be " + f"expanded from {module_out_features} to {out_features}." + ) + + has_param_with_shape_update = True parent_module_name = ".".join(name_split[:-1]) current_module_name = name_split[-1] parent_module = get_submodule(transformer, parent_module_name) @@ -2277,6 +2308,8 @@ def get_submodule(module, name): setattr(parent_module, current_module_name, expanded_module) + return has_param_with_shape_update + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. From 6523fa650ab9e7d6db2c4c24df8a64c35117a93f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 25 Nov 2024 09:08:45 +0530 Subject: [PATCH 19/58] updates. 
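The `torch.zeros_like` fixes in this patch restore the intended zero-padding when a pre-trained weight is copied into a larger, freshly created parameter. A minimal sketch of that pattern, using illustrative stand-ins and the (3072, 64) -> (3072, 128) `x_embedder` shapes that Flux Control expands (names and shapes here are for illustration only, not part of the diff):

import torch

# Illustrative stand-ins: Flux Control grows x_embedder's input features from 64 to 128.
module_weight = torch.randn(3072, 64)
expanded_module = torch.nn.Linear(128, 3072, bias=True)  # weight shape (3072, 128)

# `torch.zeros_like` takes a tensor (not a shape); the pre-trained values fill the
# leading slice and the remaining columns stay zero.
new_weight = torch.zeros_like(expanded_module.weight.data)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)

assert torch.equal(expanded_module.weight.data[:, :64], module_weight)
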
--- .../convert_flux_control_lora_to_diffusers.py | 27 +++------------ src/diffusers/loaders/lora_pipeline.py | 12 ++++--- src/diffusers/loaders/peft.py | 34 ++++++++++++++++++- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py index cd1295d9765d..031f77686190 100644 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -1,15 +1,9 @@ import argparse -from contextlib import nullcontext import safetensors.torch import torch -from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.utils.import_utils import is_accelerate_available - - -CTX = init_empty_weights if is_accelerate_available else nullcontext parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default=None, type=str) @@ -22,27 +16,13 @@ dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 -# Adapted from from the original BFL codebase. -def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict: - if name in state_dict: - print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.") - # expand with zeros: - expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) - # popular with pre-trained param for the first half. Remaining half stays with zeros. - slices = tuple(slice(0, dim) for dim in state_dict[name].shape) - expanded_state_dict_weight[slices] = state_dict[name] - state_dict[name] = expanded_state_dict_weight - - return state_dict - - def load_original_checkpoint(args): if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) elif args.checkpoint_path is not None: ckpt_path = args.checkpoint_path else: - raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") original_state_dict = safetensors.torch.load_file(ckpt_path) return original_state_dict @@ -60,7 +40,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 ): converted_state_dict = {} - original_state_dict_keys = original_state_dict.keys() + original_state_dict_keys = list(original_state_dict.keys()) for lora_key in ["lora_A", "lora_B"]: ## time_text_embed.timestep_embedder <- time_in @@ -346,7 +326,8 @@ def convert_flux_control_lora_checkpoint_to_diffusers( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") ) - print("Remaining:", original_state_dict.keys()) + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") for key in list(converted_state_dict.keys()): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7ff20e29f1b1..9a8e8473d5c1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1925,9 +1925,11 @@ def _load_norm_into_transformer( transformer_keys = set(transformer_state_dict.keys()) state_dict_keys = set(state_dict.keys()) extra_keys = list(state_dict_keys - transformer_keys) - logger.warning( - 
f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." - ) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) for key in extra_keys: state_dict.pop(key) @@ -2292,7 +2294,7 @@ def get_submodule(module, name): ) new_weight = torch.zeros_like( - expanded_module.weight.data.shape, device=module_weight.device, dtype=module_weight.dtype + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) slices = tuple(slice(0, dim) for dim in module_weight.shape) new_weight[slices] = module_weight @@ -2300,7 +2302,7 @@ def get_submodule(module, name): if bias: new_bias = torch.zeros_like( - expanded_module.bias.data.shape, device=module_bias.device, dtype=module_bias.dtype + expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype ) slices = tuple(slice(0, dim) for dim in module_bias.shape) new_bias[slices] = module_bias diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 7b842e752699..d425f8a3a8b9 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -56,6 +56,37 @@ } +def _maybe_adjust_config(config): + rank_pattern = config["rank_pattern"].copy() + target_modules = config["target_modules"] + original_r = config["r"] + + for key in list(rank_pattern.keys()): + key_rank = rank_pattern[key] + + # try to detect ambiguity + exact_matches = [mod for mod in target_modules if mod == key] + substring_matches = [mod for mod in target_modules if key in mod and mod != key] + ambiguous_key = key + + if exact_matches and substring_matches: + # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example) + config["r"] = key_rank + # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead + del config["rank_pattern"][key] + for mod in substring_matches: + # avoid overwriting if the module already has a specific rank + if mod not in config["rank_pattern"]: + config["rank_pattern"][mod] = original_r + + # update the rest of the keys with the `original_r` + for mod in target_modules: + if mod != ambiguous_key and mod not in config["rank_pattern"]: + config["rank_pattern"][mod] = original_r + + return config + + class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. 
For @@ -226,7 +257,8 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - print(lora_config_kwargs) + lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): From 81ab40b0c81dce2043324c6aadb44e11d6843394 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 25 Nov 2024 17:56:35 +0530 Subject: [PATCH 20/58] updates --- src/diffusers/pipelines/flux/pipeline_flux_control.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 04a93ba6351c..1f0c08d98acf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -762,7 +762,15 @@ def __call__( ) # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 8 + if self.transformer.x_embedder.weight.data.shape[1] != self.transformer.config.in_channels: + logger.info( + f"Different number of in_channels found in the transformer. " + f"`transformer.config.in_channels` is {self.transformer.config.in_channels}, whereas the " + f"x_embedder.weight.data.shape[1] is {self.transformer.x_embedder.weight.data.shape[1]}." + ) + num_channels_latents = self.transformer.x_embedder.weight.data.shape[1] // 8 + else: + num_channels_latents = self.transformer.config.in_channels // 8 control_image = self.prepare_image( image=control_image, From 16336193196a2dd0ca11d85c04a77eb3de2b8713 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Nov 2024 09:49:07 +0530 Subject: [PATCH 21/58] support register_config. --- src/diffusers/loaders/lora_pipeline.py | 11 ++++++++++- src/diffusers/pipelines/flux/pipeline_flux_control.py | 11 +---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9a8e8473d5c1..2b03686a71d4 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -61,6 +61,8 @@ UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" +_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} + class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" @@ -2271,7 +2273,7 @@ def get_submodule(module, name): module_out_features, module_in_features = module_weight.shape if out_features < module_out_features or in_features < module_in_features: raise NotImplementedError( - f"Only LoRAs with input/output features higher than the current modules' input/output features " + f"Only LoRAs with input/output features higher than the current module's input/output features " f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " f"this please open an issue at https://github.com/huggingface/diffusers/issues." 
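For reference, a self-contained sketch of the config bookkeeping added in the next hunk; `DummyConfig`, the mapping, and the Linear layer are stand-ins rather than diffusers objects. Once the `x_embedder` module is swapped for its expanded counterpart, the registered `in_channels` value is updated so that `pipeline_flux_control.py` can keep deriving `num_channels_latents` from `transformer.config.in_channels`:

import torch.nn as nn

# Stand-ins, not diffusers objects.
class DummyConfig:
    in_channels = 64

_module_name_to_attribute = {"x_embedder": "in_channels"}  # mirrors _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX

config = DummyConfig()
expanded_module = nn.Linear(128, 3072)  # expanded x_embedder; weight shape is (3072, 128)

attribute_name = _module_name_to_attribute["x_embedder"]
old_value = getattr(config, attribute_name)            # 64
new_value = int(expanded_module.weight.data.shape[1])  # 128
setattr(config, attribute_name, new_value)

# Downstream, the control pipeline derives its latent channel count from the config:
num_channels_latents = config.in_channels // 8          # 128 // 8 == 16
assert (old_value, new_value, num_channels_latents) == (64, 128, 16)
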
@@ -2310,6 +2312,13 @@ def get_submodule(module, name): setattr(parent_module, current_module_name, expanded_module) + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + return has_param_with_shape_update diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 1f0c08d98acf..492bc520753d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -762,16 +762,7 @@ def __call__( ) # 4. Prepare latent variables - if self.transformer.x_embedder.weight.data.shape[1] != self.transformer.config.in_channels: - logger.info( - f"Different number of in_channels found in the transformer. " - f"`transformer.config.in_channels` is {self.transformer.config.in_channels}, whereas the " - f"x_embedder.weight.data.shape[1] is {self.transformer.x_embedder.weight.data.shape[1]}." - ) - num_channels_latents = self.transformer.x_embedder.weight.data.shape[1] // 8 - else: - num_channels_latents = self.transformer.config.in_channels // 8 - + num_channels_latents = self.transformer.config.in_channels // 8 control_image = self.prepare_image( image=control_image, width=width, From b9039b18d5a61955a26383b35a0e760941aee8a7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Nov 2024 10:06:12 +0530 Subject: [PATCH 22/58] fix --- src/diffusers/loaders/lora_pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2b03686a71d4..7686d8bbf526 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2187,9 +2187,13 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ - if len(self._transformer_norm_layers.keys()) > 0: + if ( + hasattr(self, "_transformer_norm_layers") + and isinstance(self._transformer_norm_layers, dict) + and len(self._transformer_norm_layers.keys()) > 0 + ): logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." 
) From 5f94d7469f83a678eb6df5db2988b407f1ef2f78 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Nov 2024 10:31:53 +0530 Subject: [PATCH 23/58] fix --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7686d8bbf526..39ccc93aeccf 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2217,7 +2217,8 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - transformer.load_state_dict(self._transformer_norm_layers, strict=False) + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) super().unfuse_lora(components=components) From bd31651ae7bf16c18bd30be5b8216e848da934d7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Nov 2024 10:40:00 +0530 Subject: [PATCH 24/58] fix --- src/diffusers/loaders/lora_pipeline.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 39ccc93aeccf..8ad773f50e53 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2222,6 +2222,14 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) + # We override this here account for `_transformer_norm_layers`. + def unload_lora_weights(self): + super().unload_lora_weights() + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + @classmethod def _maybe_expand_transformer_param_shape_or_error_( cls, From f54ec56f95d30c9982718681a7f0ab281be42423 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 28 Nov 2024 12:52:16 +0530 Subject: [PATCH 25/58] updates --- scripts/convert_flux_control_lora_to_diffusers.py | 2 +- src/diffusers/loaders/peft.py | 11 +++++++++++ src/diffusers/utils/peft_utils.py | 3 +++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py index 031f77686190..2fb25184d00d 100644 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ b/scripts/convert_flux_control_lora_to_diffusers.py @@ -185,7 +185,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers( f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" ) if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff.net.0.proj{lora_key}..bias"] = original_state_dict.pop( + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" ) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d425f8a3a8b9..e15e0b820716 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -268,6 +268,17 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans else: if 
is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index dcc78a547a13..5947d5c558c7 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -180,6 +180,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) + # for now we know that the "bias" keys are only associated with `lora_B`. + lora_bias = any("lora_B" and "bias" in k for k in peft_state_dict) lora_config_kwargs = { "r": r, @@ -188,6 +190,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True "alpha_pattern": alpha_pattern, "target_modules": target_modules, "use_dora": use_dora, + "lora_bias": lora_bias, } return lora_config_kwargs From 8032405188c56c250814c283ad13554fd510880b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 28 Nov 2024 12:59:39 +0530 Subject: [PATCH 26/58] updates --- src/diffusers/loaders/peft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index e15e0b820716..54921a104f6e 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -271,12 +271,12 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if "lora_bias" in lora_config_kwargs: if lora_config_kwargs["lora_bias"]: - if is_peft_version("<", "0.13.2"): + if is_peft_version("<=", "0.13.2"): raise ValueError( "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: - if is_peft_version("<", "0.13.2"): + if is_peft_version("<=", "0.13.2"): lora_config_kwargs.pop("lora_bias") lora_config = LoraConfig(**lora_config_kwargs) From 6b70bf77f8d147eea9b19291bda3844ab318483f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 28 Nov 2024 13:49:16 +0530 Subject: [PATCH 27/58] updates --- src/diffusers/loaders/lora_pipeline.py | 56 ++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 8ad773f50e53..1dab1e3e46d7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -410,6 +410,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -419,6 +420,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -950,6 +962,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -1447,6 +1470,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -2064,6 +2098,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -2497,6 +2542,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name From 3726e2d0d2ec1f7f8e0542d5b9d11acd4ee3e2f0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 28 Nov 2024 13:50:03 +0530 Subject: [PATCH 28/58] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1dab1e3e46d7..ec0cf0005966 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -953,6 +953,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1461,6 +1462,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -2089,6 +2091,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -2533,6 +2536,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): From 908d151d68cb9d4bee89638f03a72a71d4951dd4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 17:17:08 +0530 Subject: [PATCH 29/58] fix --- src/diffusers/loaders/lora_pipeline.py | 7 ++++++- src/diffusers/loaders/peft.py | 10 ++++++++++ src/diffusers/utils/peft_utils.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ec0cf0005966..6cb57b5de884 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1981,7 +1981,8 @@ def _load_norm_into_transformer( logger.info( "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." ) # We can't load with strict=True because the current state_dict does not contain all the transformer keys @@ -2286,6 +2287,10 @@ def _maybe_expand_transformer_param_shape_or_error_( norm_state_dict=None, prefix=None, ) -> bool: + """ + Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and + generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. 
+ """ state_dict = {} if lora_state_dict is not None: state_dict.update(lora_state_dict) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 54921a104f6e..b39936df84a1 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -57,6 +57,12 @@ def _maybe_adjust_config(config): + """ + We may run into some ambiguous configuration values when a model has module names, sharing a common prefix + (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This + method removes the ambiguity by following what is described here: + https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. + """ rank_pattern = config["rank_pattern"].copy() target_modules = config["target_modules"] original_r = config["r"] @@ -65,6 +71,10 @@ def _maybe_adjust_config(config): key_rank = rank_pattern[key] # try to detect ambiguity + # `target_modules` can also be a str, in which case this loop would loop + # over the chars of the str. The technically correct way to match LoRA keys + # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). + # But this cuts it for now. exact_matches = [mod for mod in target_modules if mod == key] substring_matches = [mod for mod in target_modules if key in mod and mod != key] ambiguous_key = key diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 5947d5c558c7..a518596f4756 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -181,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) # for now we know that the "bias" keys are only associated with `lora_B`. - lora_bias = any("lora_B" and "bias" in k for k in peft_state_dict) + lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) lora_config_kwargs = { "r": r, From 07d44e7ac01f5a4fdd3b5f6bb188209d1dcf7fea Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 21:31:19 +0100 Subject: [PATCH 30/58] apply suggestions from review --- src/diffusers/loaders/lora_pipeline.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 6cb57b5de884..29681b9fb11d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import os from typing import Callable, Dict, List, Optional, Union @@ -1956,7 +1957,7 @@ def _load_norm_into_transformer( prefix = prefix or cls.transformer_name for key in list(state_dict.keys()): if key.split(".")[0] == prefix: - state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) # Find invalid keys transformer_state_dict = transformer.state_dict() @@ -2278,6 +2279,7 @@ def unload_lora_weights(self): transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + transformer._transformer_norm_layers = None @classmethod def _maybe_expand_transformer_param_shape_or_error_( @@ -2303,15 +2305,6 @@ def _maybe_expand_transformer_param_shape_or_error_( if key.split(".")[0] == prefix: state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) - def get_submodule(module, name): - for part in name.split("."): - if len(name) == 0: - break - if not hasattr(module, part): - raise AttributeError(f"Submodule '{part}' not found in '{module}'.") - module = getattr(module, part) - return module - # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False @@ -2320,12 +2313,9 @@ def get_submodule(module, name): module_weight = module.weight.data module_bias = module.bias.data if hasattr(module, "bias") else None bias = module_bias is not None - name_split = name.split(".") - lora_A_name = f"{name}.lora_A" - lora_B_name = f"{name}.lora_B" - lora_A_weight_name = f"{lora_A_name}.weight" - lora_B_weight_name = f"{lora_B_name}.weight" + lora_A_weight_name = f"{name}.lora_A.weight" + lora_B_weight_name = f"{name}.lora_B.weight" if lora_A_weight_name not in state_dict.keys(): continue @@ -2353,9 +2343,8 @@ def get_submodule(module, name): ) has_param_with_shape_update = True - parent_module_name = ".".join(name_split[:-1]) - current_module_name = name_split[-1] - parent_module = get_submodule(transformer, parent_module_name) + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype From b66e691058652815b7791fea1c4aecb8cba26dd4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 23:56:08 +0100 Subject: [PATCH 31/58] add tests --- src/diffusers/loaders/lora_pipeline.py | 30 ++--- tests/lora/test_lora_layers_flux.py | 153 +++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 29681b9fb11d..f91d35de81e4 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1835,22 +1835,6 @@ def load_lora_weights( if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") - def prune_state_dict_(state_dict): - pruned_keys = [] - for key in list(state_dict.keys()): - is_lora_key_present = "lora" in key - is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys) - if not is_lora_key_present and not is_norm_key_present: - state_dict.pop(key) - pruned_keys.append(key) - return pruned_keys - - pruned_keys = prune_state_dict_(state_dict) - if len(pruned_keys) > 0: - logger.warning( - 
f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}" - ) - transformer_lora_state_dict = { k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k } @@ -1883,7 +1867,7 @@ def prune_state_dict_(state_dict): ) if len(transformer_norm_state_dict) > 0: - self._transformer_norm_layers = self._load_norm_into_transformer( + transformer._transformer_norm_layers = self._load_norm_into_transformer( transformer_norm_state_dict, transformer=transformer, discard_original_layers=False, @@ -1977,7 +1961,7 @@ def _load_norm_into_transformer( overwritten_layers_state_dict = {} if not discard_original_layers: for key in state_dict.keys(): - overwritten_layers_state_dict[key] = transformer_state_dict[key] + overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() logger.info( "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " @@ -2237,10 +2221,12 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer if ( - hasattr(self, "_transformer_norm_layers") - and isinstance(self._transformer_norm_layers, dict) - and len(self._transformer_norm_layers.keys()) > 0 + hasattr(transformer, "_transformer_norm_layers") + and isinstance(transformer._transformer_norm_layers, dict) + and len(transformer._transformer_norm_layers.keys()) > 0 ): logger.info( "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " @@ -2303,7 +2289,7 @@ def _maybe_expand_transformer_param_shape_or_error_( prefix = prefix or cls.transformer_name for key in list(state_dict.keys()): if key.split(".")[0] == prefix: - state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key) + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index e6e87c7ba939..0396898ebcf9 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -24,12 +24,15 @@ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.utils import logging from diffusers.utils.testing_utils import ( + CaptureLogger, floats_tensor, is_peft_available, nightly, numpy_cosine_similarity_distance, require_peft_backend, + require_peft_version_greater, require_torch_gpu, slow, torch_device, @@ -108,6 +111,30 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs + def get_dummy_tensor_inputs(self, device=None): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = 
torch.randn((height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + def test_with_alpha_in_state_dict(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -156,6 +183,132 @@ def test_with_alpha_in_state_dict(self): ) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + def test_with_norm_in_state_dict(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_tensor_inputs(torch_device) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + with torch.no_grad(): + original_output = pipe.transformer(**inputs)[0] + + for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: + norm_state_dict = {} + for name, module in pipe.transformer.named_modules(): + if norm_layer not in name or not hasattr(module, "weight") or module.weight is None: + continue + norm_state_dict[f"transformer.{name}.weight"] = torch.randn( + module.weight.shape, device=module.weight.device, dtype=module.weight.dtype + ) + + with torch.no_grad(): + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(norm_state_dict) + lora_load_output = pipe.transformer(**inputs)[0] + self.assertTrue( + cap_logger.out.startswith( + "The provided state dict contains normalization layers in addition to LoRA layers" + ) + ) + + pipe.unload_lora_weights() + lora_unload_output = pipe.transformer(**inputs)[0] + + self.assertTrue(pipe.transformer._transformer_norm_layers is None) + self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5)) + self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) + + with CaptureLogger(logger) as cap_logger: + for key in list(norm_state_dict.keys()): + norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) + pipe.load_lora_weights(norm_state_dict) + + self.assertTrue( + cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") + ) + + def test_lora_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_tensor_inputs(torch_device) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + with torch.no_grad(): + original_output = pipe.transformer(**inputs)[0] + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + inputs["hidden_states"] = 
torch.cat([inputs["hidden_states"]] * 2, dim=2) + with torch.no_grad(): + expanded_output = pipe.transformer(**inputs)[0] + pipe.delete_adapters("adapter-1") + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + self.assertFalse(np.allclose(original_output, expanded_output, atol=1e-3, rtol=1e-3)) + + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + dummy_lora_A = torch.nn.Linear(1, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + # We should error out because lora input features is less than original. We only + # support expanding the module, not shrinking it + with self.assertRaises(NotImplementedError): + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + @require_peft_version_greater("0.13.2") + def test_lora_B_bias(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_tensor_inputs(torch_device) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + with torch.no_grad(): + original_output = pipe.transformer(**inputs)[0] + + denoiser_lora_config.lora_bias = False + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + with torch.no_grad(): + lora_bias_false_output = pipe.transformer(**inputs)[0] + pipe.delete_adapters("adapter-1") + + denoiser_lora_config.lora_bias = True + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + with torch.no_grad(): + lora_bias_true_output = pipe.transformer(**inputs)[0] + + self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From 66d746606d5fa9e1ae4977fe88fdf3bd6592724b Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Dec 2024 09:51:16 +0100 Subject: [PATCH 32/58] remove conversion script; enable on-the-fly conversion --- .../convert_flux_control_lora_to_diffusers.py | 353 ------------------ .../loaders/lora_conversion_utils.py | 207 ++++++++++ src/diffusers/loaders/lora_pipeline.py | 6 + 3 files changed, 213 insertions(+), 353 deletions(-) delete mode 100644 scripts/convert_flux_control_lora_to_diffusers.py diff --git a/scripts/convert_flux_control_lora_to_diffusers.py b/scripts/convert_flux_control_lora_to_diffusers.py deleted file mode 100644 index 2fb25184d00d..000000000000 --- a/scripts/convert_flux_control_lora_to_diffusers.py +++ /dev/null @@ -1,353 +0,0 @@ -import argparse - -import safetensors.torch -import torch -from huggingface_hub import hf_hub_download - - -parser = argparse.ArgumentParser() -parser.add_argument("--original_state_dict_repo_id", default=None, type=str) -parser.add_argument("--filename", default="flux-canny-dev-lora.safetensors", type=str) -parser.add_argument("--checkpoint_path", default=None, type=str) 
-parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="bf16") - -args = parser.parse_args() -dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 - - -def load_original_checkpoint(args): - if args.original_state_dict_repo_id is not None: - ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) - elif args.checkpoint_path is not None: - ckpt_path = args.checkpoint_path - else: - raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") - - original_state_dict = safetensors.torch.load_file(ckpt_path) - return original_state_dict - - -# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; -# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation -def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - -def convert_flux_control_lora_checkpoint_to_diffusers( - original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 -): - converted_state_dict = {} - original_state_dict_keys = list(original_state_dict.keys()) - - for lora_key in ["lora_A", "lora_B"]: - ## time_text_embed.timestep_embedder <- time_in - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") - if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") - - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") - if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") - - ## time_text_embed.text_embedder <- vector_in - converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( - f"vector_in.in_layer.{lora_key}.weight" - ) - if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop( - f"vector_in.in_layer.{lora_key}.bias" - ) - - converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop( - f"vector_in.out_layer.{lora_key}.weight" - ) - if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop( - f"vector_in.out_layer.{lora_key}.bias" - ) - - # guidance - has_guidance = any("guidance" in k for k in original_state_dict) - if has_guidance: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") - if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") - - converted_state_dict[ - 
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") - if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") - - # context_embedder - converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( - f"txt_in.{lora_key}.weight" - ) - if f"txt_in.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( - f"txt_in.{lora_key}.bias" - ) - - # x_embedder - converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight") - if f"img_in.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias") - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - - for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): - # norms - converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight" - ) - if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" - ) - - # Q, K, V - if lora_key == "lora_A": - sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - - context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - else: - sample_q, sample_k, sample_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) - - context_q, context_k, context_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = 
torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) - - if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) - - if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) - - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight" - ) - if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight" - ) - if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" - ) - - # output projections. 
- converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight" - ) - if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" - ) - - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - - # single transfomer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." - - for lora_key in ["lora_A", "lora_B"]: - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( - f"single_blocks.{i}.modulation.lin.{lora_key}.weight" - ) - if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( - f"single_blocks.{i}.modulation.lin.{lora_key}.bias" - ) - - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - - if lora_key == "lora_A": - lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) - - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: - lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) - else: - q, k, v, mlp = torch.split( - original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = 
torch.cat([v]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) - - if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: - q_bias, k_bias, v_bias, mlp_bias = torch.split( - original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) - - # output projections. - converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( - f"single_blocks.{i}.linear2.{lora_key}.weight" - ) - if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( - f"single_blocks.{i}.linear2.{lora_key}.bias" - ) - - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - - for lora_key in ["lora_A", "lora_B"]: - converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( - f"final_layer.linear.{lora_key}.weight" - ) - if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( - f"final_layer.linear.{lora_key}.bias" - ) - - converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( - original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") - ) - if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( - original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") - ) - - if len(original_state_dict) > 0: - raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") - - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict - - -def main(args): - original_ckpt = load_original_checkpoint(args) - - num_layers = 19 - num_single_layers = 38 - inner_dim = 3072 - mlp_ratio = 4.0 - - converted_control_lora_state_dict = convert_flux_control_lora_checkpoint_to_diffusers( - original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio - ) - safetensors.torch.save_file(converted_control_lora_state_dict, args.output_path) - - -if __name__ == "__main__": - main(args) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 51a406b2f6a3..85d2987f0b70 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -663,3 +663,210 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def _convert_bfl_flux_control_lora_to_diffusers(old_state_dict): + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in 
diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + def remap_double_blocks(key, converted_state_dict, original_state_dict): + block_index = int(key.split(".")[1]) + + if "img_attn.qkv.lora_A" in key: + sample_lora_weight = original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_A.weight") + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_A.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_A.weight"] = torch.cat( + [sample_lora_weight] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_A.weight"] = torch.cat( + [sample_lora_weight] + ) + + elif "txt_attn.qkv.lora_A" in key: + context_lora_weight = original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_A.weight") + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_A.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_A.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_A.weight"] = torch.cat( + [context_lora_weight] + ) + + elif "img_attn.qkv.lora_B.weight" in key: + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_B.weight"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_B.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_B.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_B.weight"] = torch.cat([sample_v]) + + elif "img_attn.qkv.lora_B.bias" in key: + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_B.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_B.bias"] = torch.cat( + [sample_q_bias] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_B.bias"] = torch.cat( + [sample_k_bias] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_B.bias"] = torch.cat( + [sample_v_bias] + ) + + elif "txt_attn.qkv.lora_B.weight" in key: + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_B.weight"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_B.weight"] = torch.cat( + [context_q] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_B.weight"] = torch.cat( + [context_k] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_B.weight"] = torch.cat( + [context_v] + ) + + elif "txt_attn.qkv.lora_B.bias" in key: + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_B.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_B.bias"] = torch.cat( + [context_q_bias] + ) + converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_B.bias"] = torch.cat( + [context_k_bias] + ) + 
converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_B.bias"] = torch.cat( + [context_v_bias] + ) + + else: + new_key = key.replace("double_blocks", "transformer_blocks") + new_key = new_key.replace("img_mod.lin", "norm1.linear") + new_key = new_key.replace("txt_mod.lin", "norm1_context.linear") + new_key = new_key.replace("img_mlp.0", "ff.net.0.proj") + new_key = new_key.replace("img_mlp.2", "ff.net.2.proj") + new_key = new_key.replace("txt_mlp.0", "ff_context.net.0.proj") + new_key = new_key.replace("txt_mlp.2", "ff_context.net.2.proj") + new_key = new_key.replace("img_attn.proj", "attn.to_out.0") + new_key = new_key.replace("img_attn.norm.query_norm.scale", "attn.norm_q.weight") + new_key = new_key.replace("img_attn.norm.key_norm.scale", "attn.norm_k.weight") + new_key = new_key.replace("txt_attn.proj", "attn.to_add_out.0") + # new_key = new_key.replace("txt_attn.norm.query_norm.scale", "attn.norm_added_q.weight") + # new_key = new_key.replace("txt_attn.norm.key_norm.scale", "attn.norm_added_k.weight") + converted_state_dict[new_key] = original_state_dict.pop(key) + + def remap_single_blocks(key, converted_state_dict, original_state_dict): + block_index = int(key.split(".")[1]) + + # Hardcoded for now. Can try to infer from state dict in future + inner_dim = 3072 + mlp_ratio = 4.0 + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + + if "linear1.lora_A.weight" in key: + lora_weight = original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_A.weight") + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_A.weight"] = torch.cat( + [lora_weight] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_A.weight"] = torch.cat( + [lora_weight] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_A.weight"] = torch.cat( + [lora_weight] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_A.weight"] = torch.cat( + [lora_weight] + ) + + elif "linear1.lora_B.weight" in key: + q, k, v, mlp = torch.split( + original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_B.weight"), split_size, dim=0 + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_B.weight"] = torch.cat([q]) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_B.weight"] = torch.cat([k]) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_B.weight"] = torch.cat([v]) + converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_B.weight"] = torch.cat([mlp]) + + elif "linear1.lora_B.bias" in key: + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_B.bias"), split_size, dim=0 + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_B.bias"] = torch.cat( + [q_bias] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_B.bias"] = torch.cat( + [k_bias] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_B.bias"] = torch.cat( + [v_bias] + ) + converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_B.bias"] = torch.cat( + [mlp_bias] + ) + + else: + new_key = key.replace("modulation.lin", "norm.linear") + new_key = new_key.replace("linear2", "proj_out") + # new_key = new_key.replace("norm.query_norm.scale", "attn.norm_q.weight") + # new_key = 
new_key.replace("norm.key_norm.scale", "attn.norm_k.weight") + converted_state_dict[new_key] = original_state_dict.pop(key) + + def remap_final_layer(key, converted_state_dict, original_state_dict): + new_key = key.replace("final_layer.linear", "proj_out") + new_key = new_key.replace("final_layer.adaLN_modulation.1", "norm_out.linear") + converted_state_dict[new_key] = swap_scale_shift(original_state_dict.pop(key)) + + TRANSFORMER_KEYS_RENAME_DICT = { + "time_in.in_layer": "time_text_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_text_embed.timestep_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "guidance_in.in_layer": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_text_embed.guidance_embedder.linear_2", + "txt_in": "context_embedder", + "img_in": "x_embedder", + "final_layer.linear": "proj_out", + "final_layer.adaLN_modulation.1": "norm_out.linear", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "double_blocks": remap_double_blocks, + "single_blocks": remap_single_blocks, + "final_layer": remap_final_layer, + } + + new_state_dict = {} + + for key in list(old_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + if new_key != key: + new_state_dict[new_key] = old_state_dict.pop(key) + + for key in list(old_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, new_state_dict, old_state_dict) + + # prefix everything with transformer + for key in list(new_state_dict.keys()): + new_state_dict[f"transformer.{key}"] = new_state_dict.pop(key) + + if len(old_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {old_state_dict.keys()=}.") + + return new_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f91d35de81e4..4653afcf08cb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -35,6 +35,7 @@ ) from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa from .lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, @@ -1760,6 +1761,11 @@ def lora_state_dict( # xlabs doesn't use `alpha`. return (state_dict, None) if return_alphas else state_dict + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return (state_dict, None) if return_alphas else state_dict + # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys()) From 64c821b85428487af43081e2ee6a9ef38750bcb0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Dec 2024 17:00:29 +0530 Subject: [PATCH 33/58] bias -> lora_bias. 
--- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4653afcf08cb..777c9e57e2e0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2537,7 +2537,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.13.3 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): From 30a89a6a3f53cc51a3c2b8a69d2e0fda3c6deaca Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Dec 2024 17:01:11 +0530 Subject: [PATCH 34/58] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 777c9e57e2e0..4653afcf08cb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2537,7 +2537,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): From bca1eaa02bbd86ef7f1ba7ba2c40357c962f79c9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Dec 2024 17:01:42 +0530 Subject: [PATCH 35/58] peft.py --- src/diffusers/loaders/peft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index b39936df84a1..fa2139143a82 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -283,7 +283,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.13.3 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): From e7df1978f0e5ea3a6dda068344876493d13ecc5e Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Dec 2024 21:51:08 +0100 Subject: [PATCH 36/58] fix lora conversion --- src/diffusers/loaders/lora_conversion_utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 85d2987f0b70..dae488641b0e 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -666,7 +666,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): def _convert_bfl_flux_control_lora_to_diffusers(old_state_dict): - # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # in Flux original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; # while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use diffusers implementation def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) @@ -755,15 +755,13 @@ def remap_double_blocks(key, converted_state_dict, original_state_dict): new_key = new_key.replace("img_mod.lin", "norm1.linear") new_key = new_key.replace("txt_mod.lin", "norm1_context.linear") new_key = new_key.replace("img_mlp.0", "ff.net.0.proj") - new_key = new_key.replace("img_mlp.2", "ff.net.2.proj") + new_key = new_key.replace("img_mlp.2", "ff.net.2") new_key = new_key.replace("txt_mlp.0", "ff_context.net.0.proj") - new_key = new_key.replace("txt_mlp.2", "ff_context.net.2.proj") + new_key = new_key.replace("txt_mlp.2", "ff_context.net.2") new_key = new_key.replace("img_attn.proj", "attn.to_out.0") new_key = new_key.replace("img_attn.norm.query_norm.scale", "attn.norm_q.weight") new_key = new_key.replace("img_attn.norm.key_norm.scale", "attn.norm_k.weight") - new_key = new_key.replace("txt_attn.proj", "attn.to_add_out.0") - # new_key = new_key.replace("txt_attn.norm.query_norm.scale", "attn.norm_added_q.weight") - # new_key = new_key.replace("txt_attn.norm.key_norm.scale", "attn.norm_added_k.weight") + new_key = new_key.replace("txt_attn.proj", "attn.to_add_out") converted_state_dict[new_key] = original_state_dict.pop(key) def remap_single_blocks(key, converted_state_dict, original_state_dict): @@ -817,10 +815,9 @@ def remap_single_blocks(key, converted_state_dict, original_state_dict): ) else: - new_key = key.replace("modulation.lin", "norm.linear") + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("modulation.lin", "norm.linear") new_key = new_key.replace("linear2", "proj_out") - # new_key = new_key.replace("norm.query_norm.scale", "attn.norm_q.weight") - # new_key = new_key.replace("norm.key_norm.scale", "attn.norm_k.weight") converted_state_dict[new_key] = original_state_dict.pop(key) def remap_final_layer(key, converted_state_dict, original_state_dict): From 5fd9fda95b44b05b67768c58492a388dec865b61 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Dec 2024 12:39:27 +0530 Subject: [PATCH 37/58] changes Co-authored-by: a-r-r-o-w --- .../loaders/lora_conversion_utils.py | 448 +++++++++++------- src/diffusers/loaders/lora_pipeline.py | 18 +- .../pipelines/flux/pipeline_flux_control.py | 4 +- 3 files changed, 287 insertions(+), 183 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index dae488641b0e..8e9bb079b43a 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -665,205 +665,307 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): return new_state_dict -def _convert_bfl_flux_control_lora_to_diffusers(old_state_dict): - # in Flux original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; - # while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use diffusers implementation +def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): + converted_state_dict = {} + original_state_dict_keys = list(original_state_dict.keys()) + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight - def remap_double_blocks(key, converted_state_dict, original_state_dict): - block_index = int(key.split(".")[1]) - - if "img_attn.qkv.lora_A" in key: - sample_lora_weight = original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_A.weight") - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_A.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_A.weight"] = torch.cat( - [sample_lora_weight] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_A.weight"] = torch.cat( - [sample_lora_weight] - ) - - elif "txt_attn.qkv.lora_A" in key: - context_lora_weight = original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_A.weight") - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_A.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_A.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_A.weight"] = torch.cat( - [context_lora_weight] - ) - - elif "img_attn.qkv.lora_B.weight" in key: - sample_q, sample_k, sample_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_B.weight"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_B.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_B.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_B.weight"] = torch.cat([sample_v]) + for lora_key in ["lora_A", "lora_B"]: + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.weight" + ) + if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.bias" + ) + + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = 
original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.weight" + ) + if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.bias" + ) + + # guidance + has_guidance = any("guidance" in k for k in original_state_dict) + if has_guidance: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + + # context_embedder + converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( + f"txt_in.{lora_key}.weight" + ) + if f"txt_in.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( + f"txt_in.{lora_key}.bias" + ) + + # x_embedder + converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight") + if f"img_in.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." 
+ + for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + # norms + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" + ) - elif "img_attn.qkv.lora_B.bias" in key: - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{block_index}.img_attn.qkv.lora_B.bias"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_q.lora_B.bias"] = torch.cat( - [sample_q_bias] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_k.lora_B.bias"] = torch.cat( - [sample_k_bias] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.to_v.lora_B.bias"] = torch.cat( - [sample_v_bias] + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight" ) + if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" + ) - elif "txt_attn.qkv.lora_B.weight" in key: - context_q, context_k, context_v = torch.chunk( - original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_B.weight"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_B.weight"] = torch.cat( - [context_q] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_B.weight"] = torch.cat( - [context_k] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_B.weight"] = torch.cat( - [context_v] - ) + # Q, K, V + if lora_key == "lora_A": + sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - elif "txt_attn.qkv.lora_B.bias" in key: - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - original_state_dict.pop(f"double_blocks.{block_index}.txt_attn.qkv.lora_B.bias"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_q_proj.lora_B.bias"] = torch.cat( - [context_q_bias] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_k_proj.lora_B.bias"] = torch.cat( - [context_k_bias] - ) - converted_state_dict[f"transformer_blocks.{block_index}.attn.add_v_proj.lora_B.bias"] = torch.cat( - [context_v_bias] - ) + context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + else: + sample_q, sample_k, sample_v = torch.chunk( + 
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) - else: - new_key = key.replace("double_blocks", "transformer_blocks") - new_key = new_key.replace("img_mod.lin", "norm1.linear") - new_key = new_key.replace("txt_mod.lin", "norm1_context.linear") - new_key = new_key.replace("img_mlp.0", "ff.net.0.proj") - new_key = new_key.replace("img_mlp.2", "ff.net.2") - new_key = new_key.replace("txt_mlp.0", "ff_context.net.0.proj") - new_key = new_key.replace("txt_mlp.2", "ff_context.net.2") - new_key = new_key.replace("img_attn.proj", "attn.to_out.0") - new_key = new_key.replace("img_attn.norm.query_norm.scale", "attn.norm_q.weight") - new_key = new_key.replace("img_attn.norm.key_norm.scale", "attn.norm_k.weight") - new_key = new_key.replace("txt_attn.proj", "attn.to_add_out") - converted_state_dict[new_key] = original_state_dict.pop(key) + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) - def remap_single_blocks(key, converted_state_dict, original_state_dict): - block_index = int(key.split(".")[1]) + if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) - # Hardcoded for now. 
Can try to infer from state dict in future - inner_dim = 3072 - mlp_ratio = 4.0 - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) - if "linear1.lora_A.weight" in key: - lora_weight = original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_A.weight") - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_A.weight"] = torch.cat( - [lora_weight] - ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_A.weight"] = torch.cat( - [lora_weight] - ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_A.weight"] = torch.cat( - [lora_weight] - ) - converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_A.weight"] = torch.cat( - [lora_weight] + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" ) + if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" + ) - elif "linear1.lora_B.weight" in key: - q, k, v, mlp = torch.split( - original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_B.weight"), split_size, dim=0 + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_B.weight"] = torch.cat([q]) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_B.weight"] = torch.cat([k]) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_B.weight"] = torch.cat([v]) - converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_B.weight"] = torch.cat([mlp]) + if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" + ) - elif "linear1.lora_B.bias" in key: - q_bias, k_bias, v_bias, mlp_bias = torch.split( - original_state_dict.pop(f"single_blocks.{block_index}.linear1.lora_B.bias"), split_size, dim=0 - ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_q.lora_B.bias"] = torch.cat( - [q_bias] + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight" ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_k.lora_B.bias"] = torch.cat( - [k_bias] + if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" + ) + 
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight" ) - converted_state_dict[f"single_transformer_blocks.{block_index}.attn.to_v.lora_B.bias"] = torch.cat( - [v_bias] + if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" + ) + + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" ) - converted_state_dict[f"single_transformer_blocks.{block_index}.proj_mlp.lora_B.bias"] = torch.cat( - [mlp_bias] + if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight" ) + if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" + ) - else: - new_key = key.replace("single_blocks", "single_transformer_blocks") - new_key = new_key.replace("modulation.lin", "norm.linear") - new_key = new_key.replace("linear2", "proj_out") - converted_state_dict[new_key] = original_state_dict.pop(key) - - def remap_final_layer(key, converted_state_dict, original_state_dict): - new_key = key.replace("final_layer.linear", "proj_out") - new_key = new_key.replace("final_layer.adaLN_modulation.1", "norm_out.linear") - converted_state_dict[new_key] = swap_scale_shift(original_state_dict.pop(key)) - - TRANSFORMER_KEYS_RENAME_DICT = { - "time_in.in_layer": "time_text_embed.timestep_embedder.linear_1", - "time_in.out_layer": "time_text_embed.timestep_embedder.linear_2", - "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "guidance_in.in_layer": "time_text_embed.guidance_embedder.linear_1", - "guidance_in.out_layer": "time_text_embed.guidance_embedder.linear_2", - "txt_in": "context_embedder", - "img_in": "x_embedder", - "final_layer.linear": "proj_out", - "final_layer.adaLN_modulation.1": "norm_out.linear", - } + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." 
+ + for lora_key in ["lora_A", "lora_B"]: + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.weight" + ) + if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.bias" + ) - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "double_blocks": remap_double_blocks, - "single_blocks": remap_single_blocks, - "final_layer": remap_final_layer, - } + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + + if lora_key == "lora_A": + lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: + lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) + else: + q, k, v, mlp = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) + + # output projections. 
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.weight" + ) + if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.bias" + ) - new_state_dict = {} + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) - for key in list(old_state_dict.keys()): - new_key = key - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - if new_key != key: - new_state_dict[new_key] = old_state_dict.pop(key) + for lora_key in ["lora_A", "lora_B"]: + converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.weight" + ) + if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.bias" + ) - for key in list(old_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, new_state_dict, old_state_dict) + converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") + ) + if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") + ) - # prefix everything with transformer - for key in list(new_state_dict.keys()): - new_state_dict[f"transformer.{key}"] = new_state_dict.pop(key) + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") - if len(old_state_dict) > 0: - raise ValueError(f"`original_state_dict` should be empty at this point but has {old_state_dict.keys()=}.") + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - return new_state_dict + return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4653afcf08cb..52e1203823e6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2308,6 +2308,7 @@ def _maybe_expand_transformer_param_shape_or_error_( lora_A_weight_name = f"{name}.lora_A.weight" lora_B_weight_name = f"{name}.lora_B.weight" + lora_B_bias_name = f"{name}.lora_B.bias" if lora_A_weight_name not in state_dict.keys(): continue @@ -2349,13 +2350,16 @@ def _maybe_expand_transformer_param_shape_or_error_( new_weight[slices] = module_weight expanded_module.weight.data.copy_(new_weight) - if bias: - new_bias = torch.zeros_like( - expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype - ) - slices = tuple(slice(0, dim) for dim in module_bias.shape) - new_bias[slices] = module_bias - expanded_module.bias.data.copy_(new_bias) + bias_present_for_lora_B = lora_B_bias_name in state_dict + if bias_present_for_lora_B: + new_bias_shape = 
state_dict[lora_B_bias_name].shape + if bias and module_bias.shape < new_bias_shape: + new_bias = torch.zeros_like( + expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype + ) + slices = tuple(slice(0, dim) for dim in module_bias.shape) + new_bias[slices] = module_bias + expanded_module.bias.data.copy_(new_bias) setattr(parent_module, current_module_name, expanded_module) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 492bc520753d..8427c4916477 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -573,9 +573,7 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - if isinstance(image, torch.Tensor): - pass - else: + if not isinstance(image, torch.Tensor): image = self.image_processor.preprocess(image, height=height, width=width) image_batch_size = image.shape[0] From a8c50ba728f2b718e6bc50b587528145870dfcf5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Dec 2024 13:01:36 +0530 Subject: [PATCH 38/58] fix-copies --- src/diffusers/pipelines/flux/pipeline_flux_control.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 8427c4916477..492bc520753d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -573,7 +573,9 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - if not isinstance(image, torch.Tensor): + if isinstance(image, torch.Tensor): + pass + else: image = self.image_processor.preprocess(image, height=height, width=width) image_batch_size = image.shape[0] From b12f797bf3ad6b8c1e3271fd7c628b299ec43552 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Dec 2024 14:33:22 +0530 Subject: [PATCH 39/58] updates for tests --- tests/lora/test_lora_layers_flux.py | 60 ++++++++--------------------- 1 file changed, 15 insertions(+), 45 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0396898ebcf9..d23a0b9f9a70 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -111,30 +111,6 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def get_dummy_tensor_inputs(self, device=None): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - height = width = 4 - sequence_length = 48 - embedding_dim = 32 - - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) - - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_prompt_embeds, - "txt_ids": text_ids, - "img_ids": image_ids, - "timestep": timestep, - } - def test_with_alpha_in_state_dict(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -189,13 
+165,12 @@ def test_with_norm_in_state_dict(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_tensor_inputs(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) - with torch.no_grad(): - original_output = pipe.transformer(**inputs)[0] + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: norm_state_dict = {} @@ -206,18 +181,19 @@ def test_with_norm_in_state_dict(self): module.weight.shape, device=module.weight.device, dtype=module.weight.dtype ) - with torch.no_grad(): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(norm_state_dict) - lora_load_output = pipe.transformer(**inputs)[0] + lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( cap_logger.out.startswith( "The provided state dict contains normalization layers in addition to LoRA layers" ) ) + self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) pipe.unload_lora_weights() - lora_unload_output = pipe.transformer(**inputs)[0] + lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(pipe.transformer._transformer_norm_layers is None) self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5)) @@ -238,14 +214,11 @@ def test_lora_parameter_expanded_shapes(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_tensor_inputs(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - with torch.no_grad(): - original_output = pipe.transformer(**inputs)[0] - out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 @@ -257,12 +230,12 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - inputs["hidden_states"] = torch.cat([inputs["hidden_states"]] * 2, dim=2) - with torch.no_grad(): - expanded_output = pipe.transformer(**inputs)[0] + + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + pipe.delete_adapters("adapter-1") self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - self.assertFalse(np.allclose(original_output, expanded_output, atol=1e-3, rtol=1e-3)) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -286,24 +259,21 @@ def test_lora_B_bias(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_tensor_inputs(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) - with torch.no_grad(): - original_output = pipe.transformer(**inputs)[0] + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] denoiser_lora_config.lora_bias = False pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - with torch.no_grad(): - lora_bias_false_output = pipe.transformer(**inputs)[0] + lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.delete_adapters("adapter-1") 
denoiser_lora_config.lora_bias = True pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - with torch.no_grad(): - lora_bias_true_output = pipe.transformer(**inputs)[0] + lora_bias_true_output = pipe(**inputs)[0] self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) From f9bd3eb6b85de3d28a1b2da28c1352b8b0dea4c0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Dec 2024 15:13:26 +0530 Subject: [PATCH 40/58] fix --- tests/lora/test_lora_layers_flux.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index d23a0b9f9a70..7ed9dd00f8f3 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -196,8 +196,10 @@ def test_with_norm_in_state_dict(self): lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(pipe.transformer._transformer_norm_layers is None) - self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5)) self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) + self.assertFalse( + np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" + ) with CaptureLogger(logger) as cap_logger: for key in list(norm_state_dict.keys()): From 84c168c2d931a708675cb754027726ee78e5ed72 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 4 Dec 2024 11:22:37 +0530 Subject: [PATCH 41/58] alpha_pattern. --- src/diffusers/__init__.py | 2 +- src/diffusers/loaders/peft.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6f70a8191629..db46dc1d8801 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -338,8 +338,8 @@ "StableDiffusion3ControlNetPipeline", "StableDiffusion3Img2ImgPipeline", "StableDiffusion3InpaintPipeline", - "StableDiffusion3PAGPipeline", "StableDiffusion3PAGImg2ImgPipeline", + "StableDiffusion3PAGPipeline", "StableDiffusion3Pipeline", "StableDiffusionAdapterPipeline", "StableDiffusionAttendAndExcitePipeline", diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index fa2139143a82..a746b33f01ac 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -94,6 +94,16 @@ def _maybe_adjust_config(config): if mod != ambiguous_key and mod not in config["rank_pattern"]: config["rank_pattern"][mod] = original_r + # handle alphas to deal with cases like + # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 + has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] + if has_different_ranks: + config["lora_alpha"] = config["r"] + alpha_pattern = {} + for module_name, rank in config["rank_pattern"].items(): + alpha_pattern[module_name] = rank + config["alpha_pattern"] = alpha_pattern + return config @@ -290,7 +300,6 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans lora_config_kwargs.pop("lora_bias") lora_config = LoraConfig(**lora_config_kwargs) - # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) From be1d788ba64bb01afe89989b5fd43e3fa0857183 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 4 Dec 2024 15:08:33 +0530 Subject: [PATCH 42/58] add a test for varied lora ranks and alphas. 
--- tests/lora/test_lora_layers_flux.py | 46 ++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 7ed9dd00f8f3..4c403b985153 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -159,6 +159,7 @@ def test_with_alpha_in_state_dict(self): ) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + # flux control lora specific def test_with_norm_in_state_dict(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -210,6 +211,7 @@ def test_with_norm_in_state_dict(self): cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") ) + # flux control lora specific def test_lora_parameter_expanded_shapes(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -254,6 +256,7 @@ def test_lora_parameter_expanded_shapes(self): with self.assertRaises(NotImplementedError): pipe.load_lora_weights(lora_state_dict, "adapter-1") + # flux control lora specific @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -275,12 +278,53 @@ def test_lora_B_bias(self): denoiser_lora_config.lora_bias = True pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_bias_true_output = pipe(**inputs)[0] + lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + # for now this is flux control lora specific but can be generalized later and added to ./utils.py + def test_correct_lora_configs_with_different_ranks(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.transformer.delete_adapters("adapter-1") + + # change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank} + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + assert pipe.transformer.peft_config["adapter-1"].rank_pattern == { + "single_transformer_blocks.0.attn.to_k": updated_rank + } + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": 
updated_alpha} + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == { + "single_transformer_blocks.0.attn.to_k": updated_alpha + } + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From 5b1bcd89321b48552430e608655300f9ff0b9900 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 4 Dec 2024 15:56:53 +0530 Subject: [PATCH 43/58] revert changes in num_channels_latents = self.transformer.config.in_channels // 8 --- src/diffusers/pipelines/flux/pipeline_flux_control.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index ccdd68f000fe..cd5afc6e473a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -763,6 +763,7 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 8 + control_image = self.prepare_image( image=control_image, width=width, From cde01e38cae874e5f3d75da90926b3c5c007c517 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 4 Dec 2024 15:57:21 +0530 Subject: [PATCH 44/58] revert moe --- src/diffusers/pipelines/flux/pipeline_flux_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index cd5afc6e473a..dc3ca8cf7b09 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -763,7 +763,7 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 8 - + control_image = self.prepare_image( image=control_image, width=width, From f688ecf9ae91232be34a3749dfbc3404b11af66f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Dec 2024 17:24:17 +0530 Subject: [PATCH 45/58] add a sanity check on unexpected keys when loading norm layers. 
--- src/diffusers/loaders/lora_pipeline.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 52e1203823e6..948710d886f9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1652,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME + _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] @classmethod @validate_hf_hub_args @@ -1835,8 +1836,9 @@ def load_lora_weights( has_lora_keys = any("lora" in key for key in state_dict.keys()) # Flux Control LoRAs also have norm keys - supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] - has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys) + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") @@ -1847,7 +1849,7 @@ def load_lora_weights( transformer_norm_state_dict = { k: state_dict.pop(k) for k in list(state_dict.keys()) - if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys) + if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) } transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer @@ -1977,7 +1979,15 @@ def _load_norm_into_transformer( ) # We can't load with strict=True because the current state_dict does not contain all the transformer keys - transformer.load_state_dict(state_dict, strict=False) + incompatible_keys = transformer.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." + ) return overwritten_layers_state_dict From ecbc4cb036a84f8a8f080b9023d3e24e6c37f423 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 15:31:07 +0530 Subject: [PATCH 46/58] fixes --- src/diffusers/loaders/lora_conversion_utils.py | 6 +++++- src/diffusers/loaders/lora_pipeline.py | 17 ++++------------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 8e9bb079b43a..d1a380ff7ed1 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -673,6 +673,10 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): inner_dim = 3072 mlp_ratio = 4.0 + for k in original_state_dict: + if "bias" in k and "img_in" in k: + print(f"{k=}") + def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) @@ -750,7 +754,7 @@ def swap_scale_shift(weight): for i in range(num_layers): block_prefix = f"transformer_blocks.{i}." 
- for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]): + for lora_key in ["lora_A", "lora_B"]: # norms converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 948710d886f9..9ce618c00ed7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2318,7 +2318,7 @@ def _maybe_expand_transformer_param_shape_or_error_( lora_A_weight_name = f"{name}.lora_A.weight" lora_B_weight_name = f"{name}.lora_B.weight" - lora_B_bias_name = f"{name}.lora_B.bias" + # lora_B_bias_name = f"{name}.lora_B.bias" if lora_A_weight_name not in state_dict.keys(): continue @@ -2352,24 +2352,15 @@ def _maybe_expand_transformer_param_shape_or_error_( expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype ) - + # Only weights are expanded and biases are not. new_weight = torch.zeros_like( expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) slices = tuple(slice(0, dim) for dim in module_weight.shape) new_weight[slices] = module_weight expanded_module.weight.data.copy_(new_weight) - - bias_present_for_lora_B = lora_B_bias_name in state_dict - if bias_present_for_lora_B: - new_bias_shape = state_dict[lora_B_bias_name].shape - if bias and module_bias.shape < new_bias_shape: - new_bias = torch.zeros_like( - expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype - ) - slices = tuple(slice(0, dim) for dim in module_bias.shape) - new_bias[slices] = module_bias - expanded_module.bias.data.copy_(new_bias) + if module_bias is not None: + expanded_module.bias.data.copy_(module_bias) setattr(parent_module, current_module_name, expanded_module) From 55058e2e38799be91c249496ca2f1e2f0c23b5c6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 16:56:15 +0530 Subject: [PATCH 47/58] tests --- tests/lora/test_lora_layers_flux.py | 116 ++++++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 4c403b985153..3119a84319e8 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -21,9 +21,10 @@ import numpy as np import safetensors.torch import torch +from PIL import Image from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel from diffusers.utils import logging from diffusers.utils.testing_utils import ( CaptureLogger, @@ -159,7 +160,80 @@ def test_with_alpha_in_state_dict(self): ) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) - # flux control lora specific + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_modify_padding_mode(self): + pass + + +class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = FluxControlPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + 
transformer_kwargs = { + "patch_size": 1, + "in_channels": 8, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")), + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + def test_with_norm_in_state_dict(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -184,7 +258,7 @@ def test_with_norm_in_state_dict(self): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(norm_state_dict) - lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( cap_logger.out.startswith( @@ -211,7 +285,6 @@ def test_with_norm_in_state_dict(self): cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") ) - # flux control lora specific def test_lora_parameter_expanded_shapes(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) @@ -219,10 +292,31 @@ def test_lora_parameter_expanded_shapes(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) + # Change the transformer config to mimic a real use case. 
+ num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + original_transformer_state_dict = pipe.transformer.state_dict() + x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") + incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) + self.assertTrue( + "x_embedder.weight" in incompatible_keys.missing_keys, + "Could not find x_embedder.weight in the missing keys.", + ) + transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) + pipe.transformer = transformer + out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 @@ -234,11 +328,13 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - - pipe.delete_adapters("adapter-1") self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -256,7 +352,6 @@ def test_lora_parameter_expanded_shapes(self): with self.assertRaises(NotImplementedError): pipe.load_lora_weights(lora_state_dict, "adapter-1") - # flux control lora specific @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -264,6 +359,13 @@ def test_lora_B_bias(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # keep track of the bias values of the base layers to perform checks later. 
+ bias_values = {} + for name, module in pipe.transformer.named_modules(): + if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]): + if module.bias is not None: + bias_values[name] = module.bias.data.clone() + _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") From a8bd03b463863c19ef192894cfc0f41bcacd4fa5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 16:59:41 +0530 Subject: [PATCH 48/58] reviewer feedback --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9ce618c00ed7..a59eb825d7af 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2318,14 +2318,13 @@ def _maybe_expand_transformer_param_shape_or_error_( lora_A_weight_name = f"{name}.lora_A.weight" lora_B_weight_name = f"{name}.lora_B.weight" - # lora_B_bias_name = f"{name}.lora_B.bias" - if lora_A_weight_name not in state_dict.keys(): continue in_features = state_dict[lora_A_weight_name].shape[1] out_features = state_dict[lora_B_weight_name].shape[0] + # This means there's no need for an expansion in the params, so we simply skip. if tuple(module_weight.shape) == (out_features, in_features): continue @@ -2349,6 +2348,7 @@ def _maybe_expand_transformer_param_shape_or_error_( parent_module_name, _, current_module_name = name.rpartition(".") parent_module = transformer.get_submodule(parent_module_name) + # TODO: consider initializing this under meta device for optims. expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype ) From 49c0242062daae7dc269e30ce4c704ade6e25eba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 17:03:04 +0530 Subject: [PATCH 49/58] fix --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a59eb825d7af..033d3cdb9a49 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2246,7 +2246,7 @@ def fuse_lora( ): logger.info( "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " - 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." 
) From 8b050eabe414b3fc2ec9b54dcd7b39aff48fe9f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 17:17:32 +0530 Subject: [PATCH 50/58] proper peft version for lora_bias --- src/diffusers/loaders/lora_pipeline.py | 2 +- src/diffusers/loaders/peft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 033d3cdb9a49..01de106bd37a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2542,7 +2542,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a746b33f01ac..32df644b758d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -293,7 +293,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): From 3204627a9d9b2b537c05d5082943a6cfffd2b3b2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Dec 2024 17:27:06 +0530 Subject: [PATCH 51/58] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 01de106bd37a..eb9b42c5fbb7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -427,7 +427,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): @@ -970,7 +970,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): @@ -1479,7 +1479,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): @@ -2108,7 +2108,7 @@ def load_lora_into_text_encoder( if lora_config_kwargs["lora_bias"]: if is_peft_version("<=", "0.13.2"): raise ValueError( - "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<=", "0.13.2"): From 130e592eb82e0f53d1fca2069acc7563c4505f10 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 6 Dec 2024 21:56:19 +0100 Subject: [PATCH 52/58] remove debug code --- src/diffusers/loaders/lora_conversion_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index d1a380ff7ed1..aab87b8f4dba 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -673,10 +673,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): inner_dim = 3072 mlp_ratio = 4.0 - for k in original_state_dict: - if "bias" in k and "img_in" in k: - print(f"{k=}") - def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) From b20ec7d9bca5aa43b7fd69ddf545b27848f6bf11 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 6 Dec 2024 21:56:25 +0100 Subject: [PATCH 53/58] update docs --- docs/source/en/api/pipelines/flux.md | 59 ++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index f776dc049ebd..af9c3639e047 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -143,6 +143,35 @@ image = pipe( image.save("output.png") ``` +Canny Control is also possible with a LoRA variant of this condition. The usage is as follows: + +```python +# !pip install -U controlnet-aux +import torch +from controlnet_aux import CannyDetector +from diffusers import FluxControlPipeline +from diffusers.utils import load_image + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = CannyDetector() +control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024) + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=30.0, +).images[0] +image.save("output.png") +``` + ### Depth Control **Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. @@ -174,6 +203,36 @@ image = pipe( image.save("output.png") ``` +Depth Control is also possible with a LoRA variant of this condition. 
The usage is as follows: + +```python +# !pip install git+https://github.com/huggingface/image_gen_aux +import torch +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils import load_image +from image_gen_aux import DepthPreprocessor + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=30, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + ### Redux * Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation. From d1715d3c81834f1d1609dd39db869ccfc11ef9ed Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Dec 2024 08:35:30 +0530 Subject: [PATCH 54/58] integration tests --- tests/lora/test_lora_layers_flux.py | 62 ++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 3119a84319e8..322c992b389a 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -19,19 +19,23 @@ import unittest import numpy as np +import pytest import safetensors.torch import torch +from parameterized import parameterized from PIL import Image from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel -from diffusers.utils import logging +from diffusers.utils import load_image, logging from diffusers.utils.testing_utils import ( CaptureLogger, floats_tensor, is_peft_available, nightly, numpy_cosine_similarity_distance, + print_tensor_test, + require_big_gpu_with_torch_cuda, require_peft_backend, require_peft_version_greater, require_torch_gpu, @@ -578,3 +582,59 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 1e-3 + + +@nightly +@require_torch_gpu +@require_peft_backend +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxControlLoRAIntegrationTests(unittest.TestCase): + num_inference_steps = 10 + seed = 0 + prompt = "A robot made of exotic candies and chocolates of different kinds." 
+ + def setUp(self): + super().setUp() + + gc.collect() + torch.cuda.empty_cache() + + self.pipeline = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to("cuda") + + def tearDown(self): + super().tearDown() + + gc.collect() + torch.cuda.empty_cache() + + @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) + def test_lora(self, lora_ckpt_id): + self.pipe.load_lora_weights(lora_ckpt_id) + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipe( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + print_tensor_test(out_slice) + + assert out_slice is None From cbad4b3c10a414bb62d007edc03804042c3f41e3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Dec 2024 08:53:50 +0530 Subject: [PATCH 55/58] nis --- tests/lora/test_lora_layers_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 322c992b389a..1290f1fd1ef4 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -628,7 +628,7 @@ def test_lora(self, lora_ckpt_id): control_image=control_image, height=1024, width=1024, - num_inference_steps=50, + num_inference_steps=self.num_inference_steps, guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, output_type="np", generator=torch.manual_seed(self.seed), From cd7c15545be7b7322e3c3644a32bc546d4c0b603 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Dec 2024 09:03:23 +0530 Subject: [PATCH 56/58] fuse and unload. 
--- tests/lora/test_lora_layers_flux.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 1290f1fd1ef4..9941a99a8e1a 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -612,7 +612,9 @@ def tearDown(self): @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) def test_lora(self, lora_ckpt_id): - self.pipe.load_lora_weights(lora_ckpt_id) + self.pipeline.load_lora_weights(lora_ckpt_id) + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() if "Canny" in lora_ckpt_id: control_image = load_image( From 25616e2e11f9e3937c50951ddaba057c84a59ae8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Dec 2024 10:31:32 +0530 Subject: [PATCH 57/58] fix --- tests/lora/test_lora_layers_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 9941a99a8e1a..a0e7e79a2972 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -625,7 +625,7 @@ def test_lora(self, lora_ckpt_id): "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" ) - image = self.pipe( + image = self.pipeline( prompt=self.prompt, control_image=control_image, height=1024, From 0b83debefbedacf1db0d00989786001d0002dbae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Dec 2024 10:43:11 +0530 Subject: [PATCH 58/58] add slices. --- tests/lora/test_lora_layers_flux.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index a0e7e79a2972..8142085f981c 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -34,7 +34,6 @@ is_peft_available, nightly, numpy_cosine_similarity_distance, - print_tensor_test, require_big_gpu_with_torch_cuda, require_peft_backend, require_peft_version_greater, @@ -637,6 +636,11 @@ def test_lora(self, lora_ckpt_id): ).images out_slice = image[0, -3:, -3:, -1].flatten() - print_tensor_test(out_slice) + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) + else: + expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) - assert out_slice is None + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3
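
A note on the change in PATCH 56: the integration test switches to loading the Control LoRA, fusing it into the base transformer, and then unloading the adapter, so the rest of the test runs on plain fused weights rather than through PEFT adapter modules. A minimal sketch of that pattern, reusing the pipeline class and checkpoint IDs that already appear in the docs and tests above (illustrative only, not a substitute for the test code):

```python
import torch
from diffusers import FluxControlPipeline

# Load the base Flux pipeline and a Control LoRA (IDs taken from the patches above).
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

# Bake the LoRA deltas into the transformer weights, then drop the adapter
# modules; subsequent inference runs directly on the fused weights.
pipe.fuse_lora()
pipe.unload_lora_weights()
```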
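
For context on the assertion added in PATCH 58: the test flattens the bottom-right 3x3 corner of the last channel of the generated image (`image[0, -3:, -3:, -1]`) and compares it against a hard-coded reference slice using a cosine-similarity distance. Below is a minimal standalone sketch of that check; the local `cosine_distance` helper is a hypothetical stand-in that assumes `numpy_cosine_similarity_distance` amounts to one minus the cosine similarity of the flattened arrays, and is not the library's implementation.

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Assumed equivalent of numpy_cosine_similarity_distance: 1 - cosine similarity.
    a, b = a.flatten(), b.flatten()
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Reference slice for the Canny LoRA taken from PATCH 58; out_slice here is a
# perturbed stand-in for the corner slice of a freshly generated image.
expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516])
out_slice = expected_slice + 1e-4

assert cosine_distance(expected_slice, out_slice) < 1e-3
```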