diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index f776dc049ebd..af9c3639e047 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -143,6 +143,35 @@ image = pipe( image.save("output.png") ``` +Canny Control is also possible with a LoRA variant of this condition. The usage is as follows: + +```python +# !pip install -U controlnet-aux +import torch +from controlnet_aux import CannyDetector +from diffusers import FluxControlPipeline +from diffusers.utils import load_image + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = CannyDetector() +control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024) + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=30.0, +).images[0] +image.save("output.png") +``` + ### Depth Control **Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. @@ -174,6 +203,36 @@ image = pipe( image.save("output.png") ``` +Depth Control is also possible with a LoRA variant of this condition. The usage is as follows: + +```python +# !pip install git+https://github.com/huggingface/image_gen_aux +import torch +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils import load_image +from image_gen_aux import DepthPreprocessor + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=30, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + ### Redux * Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation. 
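Once loaded, either Control LoRA can optionally be fused into the transformer weights and the standalone LoRA layers dropped afterwards, which mirrors what the integration test added in this PR does. A minimal sketch, assuming `pipe` is a freshly constructed `FluxControlPipeline` as in the examples above:

```python
# Optional: fuse the Control LoRA into the base transformer weights so that no
# separate LoRA layers remain at inference time, then drop the now-redundant
# LoRA state (same load -> fuse -> unload order as the integration test).
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
pipe.fuse_lora()
pipe.unload_lora_weights()
```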
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 51a406b2f6a3..aab87b8f4dba 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -663,3 +663,309 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): + converted_state_dict = {} + original_state_dict_keys = list(original_state_dict.keys()) + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + for lora_key in ["lora_A", "lora_B"]: + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" + ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" + ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.weight" + ) + if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.in_layer.{lora_key}.bias" + ) + + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.weight" + ) + if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop( + f"vector_in.out_layer.{lora_key}.bias" + ) + + # guidance + has_guidance = any("guidance" in k for k in original_state_dict) + if has_guidance: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[ + f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" + ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + + # context_embedder + converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( + 
f"txt_in.{lora_key}.weight" + ) + if f"txt_in.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( + f"txt_in.{lora_key}.bias" + ) + + # x_embedder + converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight") + if f"img_in.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + + for lora_key in ["lora_A", "lora_B"]: + # norms + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" + ) + + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" + ) + + # Q, K, V + if lora_key == "lora_A": + sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) + + context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( + [context_lora_weight] + ) + else: + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) + + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) + + if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = 
torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) + + if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) + + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" + ) + + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" + ) + + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" + ) + + converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" + ) + + # output projections. 
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" + ) + if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight" + ) + if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" + ) + + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + + for lora_key in ["lora_A", "lora_B"]: + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.weight" + ) + if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.{lora_key}.bias" + ) + + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + + if lora_key == "lora_A": + lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: + lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) + else: + q, k, v, mlp = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = 
torch.cat([v]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) + + if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) + + # output projections. + converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.weight" + ) + if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.{lora_key}.bias" + ) + + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + + for lora_key in ["lora_A", "lora_B"]: + converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.weight" + ) + if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( + f"final_layer.linear.{lora_key}.bias" + ) + + converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") + ) + if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys: + converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( + original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") + ) + + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 109592c69c3e..eb9b42c5fbb7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import os from typing import Callable, Dict, List, Optional, Union @@ -34,6 +35,7 @@ ) from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa from .lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, @@ -61,6 +63,8 @@ UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" +_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} + class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" @@ -408,6 +412,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -417,6 +422,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -939,6 +955,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -948,6 +965,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -1436,6 +1464,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1445,6 +1474,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -1612,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME + _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] @classmethod @validate_hf_hub_args @@ -1721,6 +1762,11 @@ def lora_state_dict( # xlabs doesn't use `alpha`. 
return (state_dict, None) if return_alphas else state_dict + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return (state_dict, None) if return_alphas else state_dict + # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys()) @@ -1787,23 +1833,54 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) + + if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") - transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} - if len(transformer_state_dict) > 0: + transformer_lora_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k + } + transformer_norm_state_dict = { + k: state_dict.pop(k) + for k in list(state_dict.keys()) + if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + } + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) + + if has_param_with_expanded_shape: + logger.info( + "The LoRA weights contain parameters that have different shapes than expected by the transformer. " + "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " + "To get a comprehensive list of parameter names that were modified, enable debug logging." + ) + + if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( - state_dict, + transformer_lora_state_dict, network_alphas=network_alphas, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, + transformer=transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) + if len(transformer_norm_state_dict) > 0: + transformer._transformer_norm_layers = self._load_norm_into_transformer( + transformer_norm_state_dict, + transformer=transformer, + discard_original_layers=False, + ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder."
in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( @@ -1860,6 +1937,60 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) + @classmethod + def _load_norm_into_transformer( + cls, + state_dict, + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." + ) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + incompatible_keys = transformer.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." + ) + + return overwritten_layers_state_dict + @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( @@ -1962,6 +2093,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -1971,6 +2103,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -2055,7 +2198,6 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder"], @@ -2095,6 +2237,19 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if ( + hasattr(transformer, "_transformer_norm_layers") + and isinstance(transformer._transformer_norm_layers, dict) + and len(transformer._transformer_norm_layers.keys()) > 0 + ): + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." + ) + super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) @@ -2113,8 +2268,111 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + super().unfuse_lora(components=components) + # We override this here to account for `_transformer_norm_layers`. + def unload_lora_weights(self): + super().unload_lora_weights() + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + transformer._transformer_norm_layers = None + + @classmethod + def _maybe_expand_transformer_param_shape_or_error_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ) -> bool: + """ + Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and + generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
+ """ + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Expand transformer parameter shapes if they don't match lora + has_param_with_shape_update = False + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if hasattr(module, "bias") else None + bias = module_bias is not None + + lora_A_weight_name = f"{name}.lora_A.weight" + lora_B_weight_name = f"{name}.lora_B.weight" + if lora_A_weight_name not in state_dict.keys(): + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + # This means there's no need for an expansion in the params, so we simply skip. + if tuple(module_weight.shape) == (out_features, in_features): + continue + + module_out_features, module_in_features = module_weight.shape + if out_features < module_out_features or in_features < module_in_features: + raise NotImplementedError( + f"Only LoRAs with input/output features higher than the current module's input/output features " + f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " + f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " + f"this please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + logger.debug( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}, and the number of output features will be " + f"expanded from {module_out_features} to {out_features}." + ) + + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + # TODO: consider initializing this under meta device for optims. + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + expanded_module.weight.data.copy_(new_weight) + if module_bias is not None: + expanded_module.bias.data.copy_(module_bias) + + setattr(parent_module, current_module_name, expanded_module) + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + + return has_param_with_shape_update + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
@@ -2269,6 +2527,7 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -2278,6 +2537,17 @@ def load_lora_into_text_encoder( else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index bf118c88b2de..32df644b758d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -56,6 +56,57 @@ } +def _maybe_adjust_config(config): + """ + We may run into some ambiguous configuration values when a model has module names, sharing a common prefix + (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This + method removes the ambiguity by following what is described here: + https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. + """ + rank_pattern = config["rank_pattern"].copy() + target_modules = config["target_modules"] + original_r = config["r"] + + for key in list(rank_pattern.keys()): + key_rank = rank_pattern[key] + + # try to detect ambiguity + # `target_modules` can also be a str, in which case this loop would loop + # over the chars of the str. The technically correct way to match LoRA keys + # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). + # But this cuts it for now. + exact_matches = [mod for mod in target_modules if mod == key] + substring_matches = [mod for mod in target_modules if key in mod and mod != key] + ambiguous_key = key + + if exact_matches and substring_matches: + # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example) + config["r"] = key_rank + # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead + del config["rank_pattern"][key] + for mod in substring_matches: + # avoid overwriting if the module already has a specific rank + if mod not in config["rank_pattern"]: + config["rank_pattern"][mod] = original_r + + # update the rest of the keys with the `original_r` + for mod in target_modules: + if mod != ambiguous_key and mod not in config["rank_pattern"]: + config["rank_pattern"][mod] = original_r + + # handle alphas to deal with cases like + # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 + has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] + if has_different_ranks: + config["lora_alpha"] = config["r"] + alpha_pattern = {} + for module_name, rank in config["rank_pattern"].items(): + alpha_pattern[module_name] = rank + config["alpha_pattern"] = alpha_pattern + + return config + + class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. 
For @@ -216,7 +267,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans rank = {} for key, val in state_dict.items(): - if "lora_B" in key: + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: rank[key] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: @@ -224,6 +277,8 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) + if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): @@ -233,8 +288,18 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index dcc78a547a13..a518596f4756 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -180,6 +180,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) + # for now we know that the "bias" keys are only associated with `lora_B`.
+ lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) lora_config_kwargs = { "r": r, @@ -188,6 +190,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True "alpha_pattern": alpha_pattern, "target_modules": target_modules, "use_dora": use_dora, + "lora_bias": lora_bias, } return lora_config_kwargs diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index e6e87c7ba939..8142085f981c 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -19,17 +19,24 @@ import unittest import numpy as np +import pytest import safetensors.torch import torch +from parameterized import parameterized +from PIL import Image from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel +from diffusers.utils import load_image, logging from diffusers.utils.testing_utils import ( + CaptureLogger, floats_tensor, is_peft_available, nightly, numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, require_peft_backend, + require_peft_version_greater, require_torch_gpu, slow, torch_device, @@ -165,6 +172,273 @@ def test_modify_padding_mode(self): pass +class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = FluxControlPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + transformer_kwargs = { + "patch_size": 1, + "in_channels": 8, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")), + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_with_norm_in_state_dict(self): + components, _, denoiser_lora_config = 
self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: + norm_state_dict = {} + for name, module in pipe.transformer.named_modules(): + if norm_layer not in name or not hasattr(module, "weight") or module.weight is None: + continue + norm_state_dict[f"transformer.{name}.weight"] = torch.randn( + module.weight.shape, device=module.weight.device, dtype=module.weight.dtype + ) + + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(norm_state_dict) + lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + cap_logger.out.startswith( + "The provided state dict contains normalization layers in addition to LoRA layers" + ) + ) + self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) + + pipe.unload_lora_weights() + lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(pipe.transformer._transformer_norm_layers is None) + self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) + self.assertFalse( + np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" + ) + + with CaptureLogger(logger) as cap_logger: + for key in list(norm_state_dict.keys()): + norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) + pipe.load_lora_weights(norm_state_dict) + + self.assertTrue( + cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") + ) + + def test_lora_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. 
+ num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + original_transformer_state_dict = pipe.transformer.state_dict() + x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") + incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) + self.assertTrue( + "x_embedder.weight" in incompatible_keys.missing_keys, + "Could not find x_embedder.weight in the missing keys.", + ) + transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) + pipe.transformer = transformer + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + dummy_lora_A = torch.nn.Linear(1, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + # We should error out because lora input features is less than original. We only + # support expanding the module, not shrinking it + with self.assertRaises(NotImplementedError): + pipe.load_lora_weights(lora_state_dict, "adapter-1") + + @require_peft_version_greater("0.13.2") + def test_lora_B_bias(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # keep track of the bias values of the base layers to perform checks later. 
+ bias_values = {} + for name, module in pipe.transformer.named_modules(): + if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]): + if module.bias is not None: + bias_values[name] = module.bias.data.clone() + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + denoiser_lora_config.lora_bias = False + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.delete_adapters("adapter-1") + + denoiser_lora_config.lora_bias = True + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + + # for now this is flux control lora specific but can be generalized later and added to ./utils.py + def test_correct_lora_configs_with_different_ranks(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.transformer.delete_adapters("adapter-1") + + # change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank} + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + assert pipe.transformer.peft_config["adapter-1"].rank_pattern == { + "single_transformer_blocks.0.attn.to_k": updated_rank + } + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha} + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == { + "single_transformer_blocks.0.attn.to_k": updated_alpha + } + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_modify_padding_mode(self): + pass + + @slow @nightly @require_torch_gpu @@ -307,3 +581,66 @@ def 
test_flux_xlabs_load_lora_with_single_blocks(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 1e-3 + + +@nightly +@require_torch_gpu +@require_peft_backend +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxControlLoRAIntegrationTests(unittest.TestCase): + num_inference_steps = 10 + seed = 0 + prompt = "A robot made of exotic candies and chocolates of different kinds." + + def setUp(self): + super().setUp() + + gc.collect() + torch.cuda.empty_cache() + + self.pipeline = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to("cuda") + + def tearDown(self): + super().tearDown() + + gc.collect() + torch.cuda.empty_cache() + + @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) + def test_lora(self, lora_ckpt_id): + self.pipeline.load_lora_weights(lora_ckpt_id) + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) + else: + expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3
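For reference, a standalone sketch of why the new `_convert_bfl_flux_control_lora_to_diffusers` helper reuses `lora_A` for each of `to_q`/`to_k`/`to_v` but chunks `lora_B` row-wise; the shapes below are illustrative, not tied to the real checkpoints:

```python
import torch

# For a fused qkv Linear, the low-rank down-projection (lora_A) is shared by
# q, k and v, while the up-projection (lora_B) maps onto the fused qkv output
# and can therefore be chunked into per-projection blocks.
rank, dim = 4, 64
qkv_lora_A = torch.randn(rank, dim)      # copied verbatim to to_q/to_k/to_v lora_A
qkv_lora_B = torch.randn(3 * dim, rank)  # chunked into to_q/to_k/to_v lora_B

b_q, b_k, b_v = torch.chunk(qkv_lora_B, 3, dim=0)

# The q-rows of the fused low-rank update equal the update of the split q projection.
delta_fused = qkv_lora_B @ qkv_lora_A    # (3 * dim, dim)
delta_q = b_q @ qkv_lora_A               # (dim, dim)
assert torch.allclose(delta_fused[:dim], delta_q, atol=1e-5)
```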