From 6e8088f61e0296f9308599c6d9fc4762a878b9de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 28 Jul 2024 21:19:36 +0300 Subject: [PATCH 01/87] Add initial template --- examples/research_projects/anytext/README.md | 40 + .../anytext/auxiliary_latent_module.py | 3 + .../convert_original_anytext_to_diffusers.py | 2118 +++++++++++++++++ .../anytext/pipeline_anytext.py | 1353 +++++++++++ .../anytext/text_controlnet.py | 387 +++ .../anytext/text_embedding_module.py | 3 + 6 files changed, 3904 insertions(+) create mode 100644 examples/research_projects/anytext/README.md create mode 100644 examples/research_projects/anytext/auxiliary_latent_module.py create mode 100644 examples/research_projects/anytext/convert_original_anytext_to_diffusers.py create mode 100644 examples/research_projects/anytext/pipeline_anytext.py create mode 100644 examples/research_projects/anytext/text_controlnet.py create mode 100644 examples/research_projects/anytext/text_embedding_module.py diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md new file mode 100644 index 000000000000..e81142990290 --- /dev/null +++ b/examples/research_projects/anytext/README.md @@ -0,0 +1,40 @@ +# AnyTextPipeline Pipeline + +From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/) + +"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types." + +For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115). + +Prepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion) + +To convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/): + +```sh +python convert_original_anytext_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory +``` + +To learn about how to convert the fine-tuned stable diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats). + + +```py +import torch +from pipeline_anytext import AnyTextPipeline +from text_controlnet import TextControlNetModel +from diffusers import DDIMScheduler + + +controlnet = TextControlNetModel.from_pretrained("a/b", subfolder="controlnet", torch_dtype=torch.float16) +model_id = "path-to-model" +pipe = AnyTextPipeline.from_pretrained("a/b", subfolder="base", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16") + +# speed up diffusion process with faster scheduler and memory optimization +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) +# uncomment following line if torch<2.0 +#pipe.enable_xformers_memory_efficient_attention() +pipe.enable_model_cpu_offload() +# generate image +generator = torch.Generator("cpu").manual_seed(0) +image = pipe("photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream", num_inference_steps=20, generator=generator).images[0] +image +``` diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py new file mode 100644 index 000000000000..41f7665cd437 --- /dev/null +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -0,0 +1,3 @@ +# text -> glyph render -> glyph l_g -> glyph block -> + # +> fuse layer +# position l_p -> position block -> diff --git a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py new file mode 100644 index 000000000000..b74cdd783435 --- /dev/null +++ b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py @@ -0,0 +1,2118 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.""" + +import argparse +import re +from contextlib import nullcontext +from io import BytesIO +from typing import Dict, Optional, Union + +import requests +import torch +import yaml +from text_controlnet import TextControlNetModel +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers.models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available, logging + + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if controlnet: + config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, + config, + path=None, + extract_ema=False, + controlnet=False, + skip_extract_state_dict=False, + promptdiffusion=False, +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet and not promptdiffusion: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + if promptdiffusion: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + new_checkpoint["controlnet_query_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_query_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.bias" + ) + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + new_checkpoint["controlnet_query_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_query_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_cond_block.{orig_index}.bias" + ) + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + try: + config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + else: + text_model = text_encoder + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + try: + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config["model"]["params"]["embedder_config"] + + sd_clip_image_embedder_class = image_embedder_config["target"] + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] + noise_aug_class = noise_aug_config["target"] + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def convert_promptdiffusion_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = TextControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + promptdiffusion=True, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + adapter: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + vae=None, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + config_files=None, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + vae (`AutoencoderKL`, *optional*, defaults to `None`): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If + this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + config_files (`Dict[str, str]`, *optional*, defaults to `None`): + A dictionary mapping from config file names to their contents. If this parameter is `None`, the function + will load the config files by itself, if needed. Valid keys are: + - `v1`: Config file for Stable Diffusion v1 + - `v2`: Config file for Stable Diffusion v2 + - `xl`: Config file for Stable Diffusion XL + - `xl_refiner`: Config file for Stable Diffusion XL Refiner + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if isinstance(checkpoint_path_or_dict, str): + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + else: + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + elif isinstance(checkpoint_path_or_dict, dict): + checkpoint = checkpoint_path_or_dict + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + is_upscale = pipeline_class == StableDiffusionUpscalePipeline + + config_url = None + + # model_type = "v1" + if config_files is not None and "v1" in config_files: + original_config_file = config_files["v1"] + else: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + if config_files is not None and "v2" in config_files: + original_config_file = config_files["v2"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + if config_files is not None and "xl" in config_files: + original_config_file = config_files["xl"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + if config_files is not None and "xl_refiner" in config_files: + original_config_file = config_files["xl_refiner"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + if is_upscale: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" + + if config_url is not None: + original_config_file = BytesIO(requests.get(config_url).content) + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + + original_config = yaml.safe_load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None + ): + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config["model"]["params"]["network_config"] is not None: + if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if pipeline_class is None: + # Check if we have a SDXL or SD model and initialize default pipeline + if model_type not in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline + + if num_in_channels is None and pipeline_class in [ + StableDiffusionInpaintPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLControlNetInpaintPipeline, + ]: + num_in_channels = 9 + if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: + num_in_channels = 7 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config["model"]["params"]: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config["model"]["params"]: + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, path, image_size, upcast_attention, extract_ema + ) + + if "timesteps" in original_config["model"]["params"]: + num_train_timesteps = original_config["model"]["params"]["timesteps"] + else: + num_train_timesteps = 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + if "linear_start" in original_config["model"]["params"]: + beta_start = original_config["model"]["params"]["linear_start"] + else: + beta_start = 0.02 + + if "linear_end" in original_config["model"]["params"]: + beta_end = original_config["model"]["params"]["linear_end"] + else: + beta_end = 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + if pipeline_class == StableDiffusionUpscalePipeline: + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None and vae is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] + ): + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + elif vae is None: + vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + if text_encoder is None: + text_model = convert_open_clip_checkpoint( + checkpoint, config_name, local_files_only=local_files_only, **config_kwargs + ) + else: + text_model = text_encoder + + try: + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." + ) + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + elif pipeline_class == StableDiffusionUpscalePipeline: + scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" + ) + low_res_scheduler = DDPMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + ) + + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=None, + feature_extractor=None, + ) + + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained( + karlo_model, subfolder="prior", local_files_only=local_files_only + ) + + try: + prior_tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained( + karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only + ) + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + try: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." + ) + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + try: + tokenizer = ( + CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + if tokenizer is None + else tokenizer + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + is_refiner = model_type == "SDXL-Refiner" + + if (is_refiner is False) and (tokenizer is None): + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if (is_refiner is False) and (text_encoder is None): + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + if tokenizer_2 is None: + try: + tokenizer_2 = CLIPTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." + ) + + if text_encoder_2 is None: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." + + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + + if is_accelerate_available(): # SBM Now move model to cpu. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + elif adapter: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + + else: + pipeline_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "unet": unet, + "scheduler": scheduler, + } + + if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( + pipeline_class == StableDiffusionXLInpaintPipeline + ): + pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) + + if is_refiner: + pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) + + pipe = pipeline_class(**pipeline_kwargs) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, +) -> DiffusionPipeline: + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config["model"]["params"]: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet + + +def download_promptdiffusion_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, +) -> DiffusionPipeline: + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = yaml.safe_load(open(original_config_file)) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + if "control_stage_config" not in original_config["model"]["params"]: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_promptdiffusion_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--image_size", + default=512, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + + args = parser.parse_args() + + controlnet = download_promptdiffusion_from_original_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + extract_ema=args.extract_ema, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device=args.device, + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, + ) + + controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py new file mode 100644 index 000000000000..eb8e3b1145cf --- /dev/null +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -0,0 +1,1353 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnyTextPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py new file mode 100644 index 000000000000..3587ec7358b8 --- /dev/null +++ b/examples/research_projects/anytext/text_controlnet.py @@ -0,0 +1,387 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Figure out `hint_channels` enigma(?) +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import register_to_config +from diffusers.models.controlnet import ( + ControlNetConditioningEmbedding, + ControlNetModel, + ControlNetOutput, +) +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TextControlNetModel(ControlNetModel): + """ + A PromptDiffusionControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__( + in_channels, + conditioning_channels, + flip_sin_to_cos, + freq_shift, + down_block_types, + mid_block_type, + only_cross_attention, + block_out_channels, + layers_per_block, + downsample_padding, + mid_block_scale_factor, + act_fn, + norm_num_groups, + norm_eps, + cross_attention_dim, + transformer_layers_per_block, + encoder_hid_dim, + encoder_hid_dim_type, + attention_head_dim, + num_attention_heads, + use_linear_projection, + class_embed_type, + addition_embed_type, + addition_time_embed_dim, + num_class_embeds, + upcast_attention, + resnet_time_scale_shift, + projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order, + conditioning_embedding_out_channels, + global_pool_conditions, + addition_embed_type_num_heads, + ) + self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=3, + ) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_query_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`~PromptDiffusionControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_query_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond) + sample = sample + controlnet_cond + controlnet_query_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py new file mode 100644 index 000000000000..38d84982751c --- /dev/null +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -0,0 +1,3 @@ +# text -> glyph render -> glyph lines -> OCR -> linear -> + # +> Token Replacement -> FrozenCLIPEmbedderT3 +# text -> tokenizer -> From 98c2d6ed99b8866871e938b5715ba8b1d7b58f98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 29 Jul 2024 21:50:40 +0300 Subject: [PATCH 02/87] Second template --- .../anytext/auxiliary_latent_module.py | 2 +- .../anytext/frozen_clip_embedder_t3.py | 206 +++++++++++++++++ .../anytext/pipeline_anytext.py | 5 + .../anytext/text_embedding_module.py | 215 +++++++++++++++++- 4 files changed, 426 insertions(+), 2 deletions(-) create mode 100644 examples/research_projects/anytext/frozen_clip_embedder_t3.py diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 41f7665cd437..5564af8c1dc0 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -1,3 +1,3 @@ # text -> glyph render -> glyph l_g -> glyph block -> - # +> fuse layer +# +> fuse layer # position l_p -> position block -> diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py new file mode 100644 index 000000000000..7de2b8aed492 --- /dev/null +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -0,0 +1,206 @@ +import torch +from torch import nn +from transformers import AutoProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers.models.clip.modeling_clip import _build_causal_attention_mask, _expand_mask + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedderT3(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + if use_vision: + self.vit = CLIPVisionModelWithProjection.from_pretrained(version) + self.processor = AutoProcessor.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def embedding_forward( + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + embedding_manager=None, + ): + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__( + self.transformer.text_model.embeddings + ) + + def encoder_forward( + self, + inputs_embeds, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) + + def text_encoder_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None: + raise ValueError("You have to specify either input_ids") + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager + ) + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager=embedding_manager, + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, **kwargs): + batch_encoding = self.tokenizer( + text, + truncation=False, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="longest", + return_tensors="pt", + ) + input_ids = batch_encoding["input_ids"] + tokens_list = self.split_chunks(input_ids) + z_list = [] + for tokens in tokens_list: + tokens = tokens.to(self.device) + _z = self.transformer(input_ids=tokens, **kwargs) + z_list += [_z] + return torch.cat(z_list, dim=1) + + def encode(self, text, **kwargs): + return self(text, **kwargs) + + def split_chunks(self, input_ids, chunk_size=75): + tokens_list = [] + bs, n = input_ids.shape + id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] + id_end = input_ids[:, -1].unsqueeze(1) + if n == 2: # empty caption + tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) + + trimmed_encoding = input_ids[:, 1:-1] + num_full_groups = (n - 2) // chunk_size + + for i in range(num_full_groups): + group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] + group_pad = torch.cat((id_start, group, id_end), dim=1) + tokens_list.append(group_pad) + + remaining_columns = (n - 2) % chunk_size + if remaining_columns > 0: + remaining_group = trimmed_encoding[:, -remaining_columns:] + padding_columns = chunk_size - remaining_group.shape[1] + padding = id_end.expand(bs, padding_columns) + remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) + tokens_list.append(remaining_group_pad) + return tokens_list diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index eb8e3b1145cf..37e965b2b77e 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -259,6 +259,9 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) + def modify_prompt(self, prompt: str) -> str: + return prompt + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, @@ -1119,6 +1122,8 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) + prompt, texts = self.modify_prompt(prompt) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 38d84982751c..44705c19558e 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -1,3 +1,216 @@ # text -> glyph render -> glyph lines -> OCR -> linear -> - # +> Token Replacement -> FrozenCLIPEmbedderT3 +# +> Token Replacement -> FrozenCLIPEmbedderT3 # text -> tokenizer -> + +from typing import List, Optional + +import torch +from torch import nn + +from diffusers import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers + +from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TextEmbeddingModule(nn.Module): + def __init__(self): + super().__init__() + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) + else: + scale_lora_layers(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.frozen_CLIP_embedder_t3.tokenizer) + + text_inputs = self.frozen_CLIP_embedder_t3.tokenizer( + prompt, + padding="max_length", + max_length=self.frozen_CLIP_embedder_t3.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.frozen_CLIP_embedder_t3.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.frozen_CLIP_embedder_t3.tokenizer.batch_decode( + untruncated_ids[:, self.frozen_CLIP_embedder_t3.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.frozen_CLIP_embedder_t3.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.frozen_CLIP_embedder_t3.text_encoder.config, "use_attention_mask") + and self.frozen_CLIP_embedder_t3.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.frozen_CLIP_embedder_t3.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.frozen_CLIP_embedder_t3.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.frozen_CLIP_embedder_t3.text_encoder.config, "use_attention_mask") + and self.frozen_CLIP_embedder_t3.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.frozen_CLIP_embedder_t3.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds From 867bbbfdb0232fc730e97d1a71d69fbadecfd5df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 17:18:50 +0300 Subject: [PATCH 03/87] feat: Add TextEmbeddingModule to AnyTextPipeline --- examples/research_projects/anytext/pipeline_anytext.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 37e965b2b77e..90b709ecfe0b 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -20,6 +20,7 @@ import PIL.Image import torch import torch.nn.functional as F +from text_embedding_module import TextEmbeddingModule from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -221,6 +222,7 @@ def __init__( requires_safety_checker: bool = True, ): super().__init__() + self.text_embedding_module = TextEmbeddingModule() if safety_checker is None and requires_safety_checker: logger.warning( @@ -1124,7 +1126,7 @@ def __call__( ) prompt, texts = self.modify_prompt(prompt) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt_embeds, negative_prompt_embeds = self.text_embedding_module.encode_prompt( prompt, device, num_images_per_prompt, From 8818372dac4e42a429f9c299893fec04a5c42019 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 17:19:22 +0300 Subject: [PATCH 04/87] feat: Add AuxiliaryLatentModule template to AnyTextPipeline --- .../anytext/auxiliary_latent_module.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 5564af8c1dc0..fc5580f445a8 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -1,3 +1,81 @@ # text -> glyph render -> glyph l_g -> glyph block -> # +> fuse layer # position l_p -> position block -> + +import torch +from torch import nn + +from diffusers.xxx import TimestepEmbedSequential + + +# Taken from AnyText.ldm.modules.diffusionmodules.util.conv_nd +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# Copied from diffusers.models.controlnet.zero_module +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class AuxiliaryLatentModule(nn.Module): + def __init__(self, dims, model_channels, glyph_channels, position_channels): + self.glyph_block = TimestepEmbedSequential( + conv_nd(dims, glyph_channels, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.position_block = TimestepEmbedSequential( + conv_nd(dims, position_channels, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 64, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) + + def forward(self, text_info, emb, context): + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) + enc_glyph = self.glyph_block(glyphs, emb, context) + enc_pos = self.position_block(positions, emb, context) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) + + return guided_hint From 64c63ebd9d437e792255eb3ece9a588a0fe8d6fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 22:49:43 +0300 Subject: [PATCH 05/87] Add bert tokenizer from the anytext repo for now --- .../anytext/bert_tokenizer.py | 428 ++++++++++++++++++ 1 file changed, 428 insertions(+) create mode 100644 examples/research_projects/anytext/bert_tokenizer.py diff --git a/examples/research_projects/anytext/bert_tokenizer.py b/examples/research_projects/anytext/bert_tokenizer.py new file mode 100644 index 000000000000..b6d2a2e81232 --- /dev/null +++ b/examples/research_projects/anytext/bert_tokenizer.py @@ -0,0 +1,428 @@ +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function + +import collections +import re +import unicodedata + +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", + "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", + "chinese_L-12_H-768_A-12", + ] + + cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag) + ) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): # type: ignore # noqa: F821 + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): # type: ignore # noqa: F821 + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """Converts a sequence of tokens (string) in a single string.""" + + def clean_up_tokenization(out_string): + """Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abbreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = " ".join(tokens).replace(" ##", "").strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False From 92f8b79283c005afaad0726f5bf68cc865732ce2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 22:51:22 +0300 Subject: [PATCH 06/87] feat: Update AnyTextPipeline's modify_prompt method This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe. --- .../anytext/pipeline_anytext.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 90b709ecfe0b..52511243ed76 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -14,12 +14,14 @@ import inspect +import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F +from bert_tokenizer import BasicTokenizer from text_embedding_module import TextEmbeddingModule from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -49,6 +51,10 @@ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +checker = BasicTokenizer() + + +PLACE_HOLDER = "*" logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -261,8 +267,31 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - def modify_prompt(self, prompt: str) -> str: - return prompt + def modify_prompt(self, prompt): + prompt = prompt.replace("“", '"') + prompt = prompt.replace("”", '"') + p = '"(.*?)"' + strs = re.findall(p, prompt) + if len(strs) == 0: + strs = [" "] + else: + for s in strs: + prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) + if self.is_chinese(prompt): + if self.trans_pipe is None: + return None, None + old_prompt = prompt + prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] + print(f"Translate: {old_prompt} --> {prompt}") + return prompt, strs + + def is_chinese(self, text): + text = checker._clean_text(text) + for char in text: + cp = ord(char) + if checker._is_chinese_char(cp): + return True + return False # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -1284,7 +1313,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] From e9c688c013fed668b0d4e6c0434baaed047c1ce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 22:54:13 +0300 Subject: [PATCH 07/87] Fill in the `forward` pass of `AuxiliaryLatentModule` --- .../anytext/auxiliary_latent_module.py | 290 +++++++++++++++++- 1 file changed, 285 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index fc5580f445a8..c7e65fd75ace 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -2,10 +2,143 @@ # +> fuse layer # position l_p -> position block -> +import cv2 +from diffusers.utils import logging +import numpy as np import torch +from PIL import Image, ImageDraw, ImageFont from torch import nn -from diffusers.xxx import TimestepEmbedSequential + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def check_channels(image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + +def resize_image(img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + +def insert_spaces(string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + +def draw_glyph(font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + +def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img # Taken from AnyText.ldm.modules.diffusionmodules.util.conv_nd @@ -30,8 +163,10 @@ def zero_module(module: nn.Module) -> nn.Module: class AuxiliaryLatentModule(nn.Module): - def __init__(self, dims, model_channels, glyph_channels, position_channels): - self.glyph_block = TimestepEmbedSequential( + def __init__(self, dims, model_channels, glyph_channels, position_channels, font_path): + super().__init__() + self.font = ImageFont.truetype(font_path, 60) + self.glyph_block = nn.Sequential( conv_nd(dims, glyph_channels, 8, 3, padding=1), nn.SiLU(), conv_nd(dims, 8, 8, 3, padding=1), @@ -52,7 +187,7 @@ def __init__(self, dims, model_channels, glyph_channels, position_channels): nn.SiLU(), ) - self.position_block = TimestepEmbedSequential( + self.position_block = nn.Sequential( conv_nd(dims, position_channels, 8, 3, padding=1), nn.SiLU(), conv_nd(dims, 8, 8, 3, padding=1), @@ -71,7 +206,152 @@ def __init__(self, dims, model_channels, glyph_channels, position_channels): self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) - def forward(self, text_info, emb, context): + def forward( + self, + text_info, + emb, + context, + mode, + texts, + prompt, + draw_pos, + ori_image, + img_count, + max_chars=77, + revise_pos=False, + sort_priority=False, + h=512, + w=512, + ): + if prompt is None and texts is None: + return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" + n_lines = len(texts) + if mode in ["text-generation", "gen"]: + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode in ["text-editing", "edit"]: + if draw_pos is None or ori_image is None: + return None, -1, "Reference image and position image are needed for text editing!", "" + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = check_channels(edit_image) + edit_image = resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + h, w = edit_image.shape[:2] # change h, w by input ref_img + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" + if mode in ["text-editing", "edit"]: + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + return ( + None, + -1, + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!", + "", + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + info = {} + info["glyphs"] = [] + info["gly_line"] = [] + info["positions"] = [] + info["n_lines"] = [len(texts)] * img_count + gly_pos_imgs = [] + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = draw_glyph(self.font, text) + glyphs = draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + raise ValueError(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + gly_pos_img = cv2.drawContours(glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1) + gly_pos_imgs += [gly_pos_img] # for show + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show + pos = pre_pos[i] + info["glyphs"] += [self.arr2tensor(glyphs, img_count)] + info["gly_line"] += [self.arr2tensor(gly_line, img_count)] + info["positions"] += [self.arr2tensor(pos, img_count)] + + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + masked_img = torch.from_numpy(masked_img.copy()).float().cpu() + if self.use_fp16: + masked_img = masked_img.half() + encoder_posterior = self.model.encode_first_stage(masked_img[None, ...]) + masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach() + if self.use_fp16: + masked_x = masked_x.half() + info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0) + + # hint = self.arr2tensor(np_hint, img_count) + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) enc_glyph = self.glyph_block(glyphs, emb, context) From 42a41d01bf3373da72ebf2a84ba9795f41daeb31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 22:55:01 +0300 Subject: [PATCH 08/87] `make style && make quality` --- examples/research_projects/anytext/auxiliary_latent_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index c7e65fd75ace..2df944168de5 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -3,12 +3,13 @@ # position l_p -> position block -> import cv2 -from diffusers.utils import logging import numpy as np import torch from PIL import Image, ImageDraw, ImageFont from torch import nn +from diffusers.utils import logging + logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 9d50f80862232a2239f6abe69ea8d68e6b19371c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 30 Jul 2024 23:12:23 +0300 Subject: [PATCH 09/87] `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library` --- examples/research_projects/anytext/bert_tokenizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/research_projects/anytext/bert_tokenizer.py b/examples/research_projects/anytext/bert_tokenizer.py index b6d2a2e81232..fbd3f37f3f1f 100644 --- a/examples/research_projects/anytext/bert_tokenizer.py +++ b/examples/research_projects/anytext/bert_tokenizer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# TODO: Try to use the `transformers` library instead of this custom implementation if possible. """Tokenization classes.""" from __future__ import absolute_import, division, print_function From 5e1e515a75736b861d445cd15829b0bbeebf4cc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 01:19:02 +0300 Subject: [PATCH 10/87] Update error handling to raise and logging --- .../anytext/auxiliary_latent_module.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 2df944168de5..0f93119d5e48 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -225,20 +225,26 @@ def forward( w=512, ): if prompt is None and texts is None: - return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" + raise ValueError("Prompt or texts must be provided!") + #return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" n_lines = len(texts) if mode in ["text-generation", "gen"]: edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image elif mode in ["text-editing", "edit"]: if draw_pos is None or ori_image is None: - return None, -1, "Reference image and position image are needed for text editing!", "" + raise ValueError("Reference image and position image are needed for text editing!") + #return None, -1, "Reference image and position image are needed for text editing!", "" if isinstance(ori_image, str): ori_image = cv2.imread(ori_image)[..., ::-1] - assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + #assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" elif isinstance(ori_image, torch.Tensor): ori_image = ori_image.cpu().numpy() else: - assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + #assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" edit_image = ori_image.clip(1, 255) # for mask reason edit_image = check_channels(edit_image) edit_image = resize_image( @@ -250,12 +256,16 @@ def forward( pos_imgs = np.zeros((w, h, 1)) if isinstance(draw_pos, str): draw_pos = cv2.imread(draw_pos)[..., ::-1] - assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from {draw_pos}!") + #assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" pos_imgs = 255 - draw_pos elif isinstance(draw_pos, torch.Tensor): pos_imgs = draw_pos.cpu().numpy() else: - assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + #assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" if mode in ["text-editing", "edit"]: pos_imgs = cv2.resize(pos_imgs, (w, h)) pos_imgs = pos_imgs[..., 0:1] @@ -269,12 +279,7 @@ def forward( if n_lines == 1 and texts[0] == " ": pass # text-to-image without text else: - return ( - None, - -1, - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!", - "", - ) + raise ValueError(f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!") elif len(pos_imgs) > n_lines: str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." logger.warning(str_warning) @@ -323,7 +328,7 @@ def forward( contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) if len(contours) != 1: str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - raise ValueError(str_warning) + logger.warning(str_warning) else: rect = cv2.minAreaRect(contours[0]) poly = np.int0(cv2.boxPoints(rect)) From 2d10f0c279aa23268633875e8e05e7e774faa753 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 01:47:34 +0300 Subject: [PATCH 11/87] Add `create_glyph_lines` function into `TextEmbeddingModule` --- .../anytext/text_embedding_module.py | 267 +++++++++++++++++- 1 file changed, 265 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 44705c19558e..59d33a5ea43c 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -4,10 +4,13 @@ from typing import List, Optional +import cv2 +import numpy as np import torch +from PIL import Image, ImageDraw from torch import nn -from diffusers import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers @@ -17,12 +20,272 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def check_channels(image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + +def resize_image(img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + +def insert_spaces(string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + +def draw_glyph(font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + +def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img + + class TextEmbeddingModule(nn.Module): def __init__(self): super().__init__() self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def forward(self, text): + pass + + def create_glyph_lines( + self, + prompt, + texts, + mode="text-generation", + img_count=1, + max_chars=77, + revise_pos=False, + draw_pos=None, + ori_image=None, + sort_priority=False, + h=512, + w=512, + ): + if prompt is None and texts is None: + raise ValueError("Prompt or texts should be provided!") + # return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" + n_lines = len(texts) + if mode in ["text-generation", "gen"]: + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode in ["text-editing", "edit"]: + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + # return None, -1, "Reference image and position image are needed for text editing!", "" + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from{ori_image}!") + # assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + # assert isinstance(ori_image, np.ndarray), f'Unknown format of ori_image: {type(ori_image)}' + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = check_channels(edit_image) + edit_image = resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + h, w = edit_image.shape[:2] # change h, w by input ref_img + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from{draw_pos}!") + # assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + # assert isinstance(draw_pos, np.ndarray), f'Unknown format of draw_pos: {type(draw_pos)}' + if mode in ["text-editing", "edit"]: + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) + # return None, -1, f'Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!', '' + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + # prepare info dict + info = {} + info["gly_line"] = [] + gly_pos_imgs = [] + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = draw_glyph(self.font, text) + glyphs = draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + logger.warning(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + gly_pos_img = cv2.drawContours(glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1) + gly_pos_imgs += [gly_pos_img] # for show + else: + gly_line = np.zeros((80, 512, 1)) + gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show + info["gly_line"] += [self.arr2tensor(gly_line, img_count)] + + return info["gly_line"] + def encode_prompt( self, prompt, From bc197a936b29d9824ece31169785d4b7052abaa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 01:48:15 +0300 Subject: [PATCH 12/87] make style --- .../anytext/auxiliary_latent_module.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 0f93119d5e48..cdaa13454a69 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -226,25 +226,25 @@ def forward( ): if prompt is None and texts is None: raise ValueError("Prompt or texts must be provided!") - #return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" + # return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" n_lines = len(texts) if mode in ["text-generation", "gen"]: edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image elif mode in ["text-editing", "edit"]: if draw_pos is None or ori_image is None: raise ValueError("Reference image and position image are needed for text editing!") - #return None, -1, "Reference image and position image are needed for text editing!", "" + # return None, -1, "Reference image and position image are needed for text editing!", "" if isinstance(ori_image, str): ori_image = cv2.imread(ori_image)[..., ::-1] if ori_image is None: raise ValueError(f"Can't read ori_image image from {ori_image}!") - #assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" + # assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" elif isinstance(ori_image, torch.Tensor): ori_image = ori_image.cpu().numpy() else: if not isinstance(ori_image, np.ndarray): raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - #assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" + # assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" edit_image = ori_image.clip(1, 255) # for mask reason edit_image = check_channels(edit_image) edit_image = resize_image( @@ -258,14 +258,14 @@ def forward( draw_pos = cv2.imread(draw_pos)[..., ::-1] if draw_pos is None: raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - #assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" + # assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" pos_imgs = 255 - draw_pos elif isinstance(draw_pos, torch.Tensor): pos_imgs = draw_pos.cpu().numpy() else: if not isinstance(draw_pos, np.ndarray): raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - #assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" + # assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" if mode in ["text-editing", "edit"]: pos_imgs = cv2.resize(pos_imgs, (w, h)) pos_imgs = pos_imgs[..., 0:1] @@ -279,7 +279,9 @@ def forward( if n_lines == 1 and texts[0] == " ": pass # text-to-image without text else: - raise ValueError(f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!") + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) elif len(pos_imgs) > n_lines: str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." logger.warning(str_warning) From e52d8ccec12fc06ff62c3e9018c4c7ff549b3075 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 14:59:55 +0300 Subject: [PATCH 13/87] Up --- .../anytext/auxiliary_latent_module.py | 287 +++++++++--------- 1 file changed, 136 insertions(+), 151 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index cdaa13454a69..fe010bf031f7 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -14,134 +14,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def check_channels(image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - -def resize_image(img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - -def insert_spaces(string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - -def draw_glyph(font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - -def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - - # Taken from AnyText.ldm.modules.diffusionmodules.util.conv_nd def conv_nd(dims, *args, **kwargs): """ @@ -164,9 +36,10 @@ def zero_module(module: nn.Module) -> nn.Module: class AuxiliaryLatentModule(nn.Module): - def __init__(self, dims, model_channels, glyph_channels, position_channels, font_path): + def __init__(self, dims, model_channels, glyph_channels, position_channels, font_path, **kwargs): super().__init__() self.font = ImageFont.truetype(font_path, 60) + self.use_fp16 = kwargs.get("use_fp16", False) self.glyph_block = nn.Sequential( conv_nd(dims, glyph_channels, 8, 3, padding=1), nn.SiLU(), @@ -209,7 +82,6 @@ def __init__(self, dims, model_channels, glyph_channels, position_channels, font def forward( self, - text_info, emb, context, mode, @@ -246,8 +118,8 @@ def forward( raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") # assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = check_channels(edit_image) - edit_image = resize_image( + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( edit_image, max_length=768 ) # make w h multiple of 64, resize if w or h > max_length h, w = edit_image.shape[:2] # change h, w by input ref_img @@ -299,11 +171,9 @@ def forward( poly_list += [None] np_hint = np.sum(pre_pos, axis=0).clip(0, 1) # prepare info dict - info = {} - info["glyphs"] = [] - info["gly_line"] = [] - info["positions"] = [] - info["n_lines"] = [len(texts)] * img_count + glyphs_list = [] + positions = [] + n_lines = [len(texts)] * img_count gly_pos_imgs = [] for i in range(len(texts)): text = texts[i] @@ -313,8 +183,7 @@ def forward( text = text[:max_chars] gly_scale = 2 if pre_pos[i].mean() != 0: - gly_line = draw_glyph(self.font, text) - glyphs = draw_glyph2( + glyphs = self.draw_glyph2( self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False ) gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) @@ -339,12 +208,10 @@ def forward( gly_pos_imgs += [gly_pos_img] # for show else: glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_line = np.zeros((80, 512, 1)) gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show pos = pre_pos[i] - info["glyphs"] += [self.arr2tensor(glyphs, img_count)] - info["gly_line"] += [self.arr2tensor(gly_line, img_count)] - info["positions"] += [self.arr2tensor(pos, img_count)] + glyphs_list += [self.arr2tensor(glyphs, img_count)] + positions += [self.arr2tensor(pos, img_count)] # get masked_x masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) @@ -352,18 +219,136 @@ def forward( masked_img = torch.from_numpy(masked_img.copy()).float().cpu() if self.use_fp16: masked_img = masked_img.half() - encoder_posterior = self.model.encode_first_stage(masked_img[None, ...]) - masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach() + encoder_posterior = self.encode_first_stage(masked_img[None, ...]) + masked_x = self.get_first_stage_encoding(encoder_posterior).detach() if self.use_fp16: masked_x = masked_x.half() - info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0) - - # hint = self.arr2tensor(np_hint, img_count) + masked_x = torch.cat([masked_x for _ in range(img_count)], dim=0) - glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) + glyphs = torch.cat(glyphs_list, dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(positions, dim=1).sum(dim=1, keepdim=True) enc_glyph = self.glyph_block(glyphs, emb, context) enc_pos = self.position_block(positions, emb, context) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1)) return guided_hint + + + def encode_first_stage(self, masked_img): + pass + + def get_first_stage_encoding(self, encoder_posterior): + pass + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().cpu() + if self.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr + + def check_channels(self, image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + + def resize_image(self, img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img From 8c69d83c68d7014cf16bc1b3e25cacbd8ad86ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 15:06:05 +0300 Subject: [PATCH 14/87] Up --- .../anytext/auxiliary_latent_module.py | 3 - .../anytext/text_embedding_module.py | 272 +++++++++--------- 2 files changed, 135 insertions(+), 140 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index fe010bf031f7..35486eef8f80 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -233,7 +233,6 @@ def forward( return guided_hint - def encode_first_stage(self, masked_img): pass @@ -256,7 +255,6 @@ def check_channels(self, image): image = image[:, :, :3] return image - def resize_image(self, img, max_length=768): height, width = img.shape[:2] max_dimension = max(height, width) @@ -271,7 +269,6 @@ def resize_image(self, img, max_length=768): img = cv2.resize(img, (width - (width % 64), height - (height % 64))) return img - def insert_spaces(self, string, nSpace): if nSpace == 0: return string diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 59d33a5ea43c..65006a167484 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -20,142 +20,18 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def check_channels(image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - -def resize_image(img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - -def insert_spaces(string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - -def draw_glyph(font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - -def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - - class TextEmbeddingModule(nn.Module): def __init__(self): super().__init__() self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() + self.linear = nn.Linear() def forward(self, text): pass + def ocr(self, glyph_lines): + pass + def create_glyph_lines( self, prompt, @@ -192,8 +68,8 @@ def create_glyph_lines( raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") # assert isinstance(ori_image, np.ndarray), f'Unknown format of ori_image: {type(ori_image)}' edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = check_channels(edit_image) - edit_image = resize_image( + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( edit_image, max_length=768 ) # make w h multiple of 64, resize if w or h > max_length h, w = edit_image.shape[:2] # change h, w by input ref_img @@ -245,8 +121,7 @@ def create_glyph_lines( pre_pos += [np.zeros((h, w, 1))] poly_list += [None] # prepare info dict - info = {} - info["gly_line"] = [] + gly_lines = [] gly_pos_imgs = [] for i in range(len(texts)): text = texts[i] @@ -256,8 +131,8 @@ def create_glyph_lines( text = text[:max_chars] gly_scale = 2 if pre_pos[i].mean() != 0: - gly_line = draw_glyph(self.font, text) - glyphs = draw_glyph2( + gly_line = self.draw_glyph(self.font, text) + glyphs = self.draw_glyph2( self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False ) gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) @@ -282,9 +157,132 @@ def create_glyph_lines( else: gly_line = np.zeros((80, 512, 1)) gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show - info["gly_line"] += [self.arr2tensor(gly_line, img_count)] - - return info["gly_line"] + gly_lines += [self.arr2tensor(gly_line, img_count)] + + return gly_lines + + def check_channels(self, image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + def resize_image(self, img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def draw_glyph(self, font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img def encode_prompt( self, From 4a413aab2703e3badd9fddf9d81dc57e299677c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 31 Jul 2024 17:40:53 +0300 Subject: [PATCH 15/87] Up --- .../anytext/pipeline_anytext.py | 2 +- .../anytext/text_embedding_module.py | 176 +++--------------- 2 files changed, 26 insertions(+), 152 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 52511243ed76..3de2af217994 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1155,7 +1155,7 @@ def __call__( ) prompt, texts = self.modify_prompt(prompt) - prompt_embeds, negative_prompt_embeds = self.text_embedding_module.encode_prompt( + prompt_embeds, negative_prompt_embeds = self.text_embedding_module( prompt, device, num_images_per_prompt, diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 65006a167484..837433752bec 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -7,7 +7,7 @@ import cv2 import numpy as np import torch -from PIL import Image, ImageDraw +from PIL import Image, ImageDraw, ImageFont from torch import nn from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin @@ -21,52 +21,55 @@ class TextEmbeddingModule(nn.Module): - def __init__(self): + def __init__(self, font_path): super().__init__() - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() + self.font = ImageFont.truetype(font_path, 60) + self.ocr_model = ... self.linear = nn.Linear() + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() - def forward(self, text): - pass + def forward(self, texts, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + glyph_lines = self.create_glyph_lines(texts) + ocr_output = self.ocr(glyph_lines) + _ = self.linear(ocr_output) + # Token Replacement + + # FrozenCLIPEmbedderT3 + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=None, + clip_skip=None, + ) + return prompt_embeds, negative_prompt_embeds def ocr(self, glyph_lines): pass def create_glyph_lines( self, - prompt, texts, mode="text-generation", img_count=1, max_chars=77, - revise_pos=False, draw_pos=None, ori_image=None, sort_priority=False, h=512, w=512, ): - if prompt is None and texts is None: - raise ValueError("Prompt or texts should be provided!") - # return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" - n_lines = len(texts) if mode in ["text-generation", "gen"]: edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image elif mode in ["text-editing", "edit"]: - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - # return None, -1, "Reference image and position image are needed for text editing!", "" if isinstance(ori_image, str): ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from{ori_image}!") - # assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" elif isinstance(ori_image, torch.Tensor): ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - # assert isinstance(ori_image, np.ndarray), f'Unknown format of ori_image: {type(ori_image)}' edit_image = ori_image.clip(1, 255) # for mask reason edit_image = self.check_channels(edit_image) edit_image = self.resize_image( @@ -78,16 +81,9 @@ def create_glyph_lines( pos_imgs = np.zeros((w, h, 1)) if isinstance(draw_pos, str): draw_pos = cv2.imread(draw_pos)[..., ::-1] - if draw_pos is None: - raise ValueError(f"Can't read draw_pos image from{draw_pos}!") - # assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" pos_imgs = 255 - draw_pos elif isinstance(draw_pos, torch.Tensor): pos_imgs = draw_pos.cpu().numpy() - else: - if not isinstance(draw_pos, np.ndarray): - raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - # assert isinstance(draw_pos, np.ndarray), f'Unknown format of draw_pos: {type(draw_pos)}' if mode in ["text-editing", "edit"]: pos_imgs = cv2.resize(pos_imgs, (w, h)) pos_imgs = pos_imgs[..., 0:1] @@ -97,66 +93,25 @@ def create_glyph_lines( pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) if len(pos_imgs) == 0: pos_imgs = [np.zeros((h, w, 1))] - if len(pos_imgs) < n_lines: - if n_lines == 1 and texts[0] == " ": - pass # text-to-image without text - else: - raise ValueError( - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" - ) - # return None, -1, f'Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!', '' - elif len(pos_imgs) > n_lines: - str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." - logger.warning(str_warning) - # get pre_pos, poly_list, hint that needed for anytext + # get pre_pos that needed for anytext pre_pos = [] - poly_list = [] for input_pos in pos_imgs: if input_pos.mean() != 0: input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos poly, pos_img = self.find_polygon(input_pos) pre_pos += [pos_img / 255.0] - poly_list += [poly] else: pre_pos += [np.zeros((h, w, 1))] - poly_list += [None] # prepare info dict gly_lines = [] - gly_pos_imgs = [] for i in range(len(texts)): text = texts[i] if len(text) > max_chars: - str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' - logger.warning(str_warning) text = text[:max_chars] - gly_scale = 2 if pre_pos[i].mean() != 0: gly_line = self.draw_glyph(self.font, text) - glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) - if revise_pos: - resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) - new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos - contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - logger.warning(str_warning) - else: - rect = cv2.minAreaRect(contours[0]) - poly = np.int0(cv2.boxPoints(rect)) - gly_pos_img = cv2.drawContours(glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1) - gly_pos_imgs += [gly_pos_img] # for show else: gly_line = np.zeros((80, 512, 1)) - gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show gly_lines += [self.arr2tensor(gly_line, img_count)] return gly_lines @@ -183,14 +138,6 @@ def resize_image(self, img, max_length=768): img = cv2.resize(img, (width - (width % 64), height - (height % 64))) return img - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - def draw_glyph(self, font, text): g_size = 50 W, H = (512, 80) @@ -211,79 +158,6 @@ def draw_glyph(self, font, text): img = np.expand_dims(np.array(img), axis=2).astype(np.float64) return img - def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = self.insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = self.insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - def encode_prompt( self, prompt, From 571608be1204e373afda4b0c6aa033af5f699bb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 09:50:59 +0300 Subject: [PATCH 16/87] Up --- .../anytext/auxiliary_latent_module.py | 13 ++++++++++++- .../research_projects/anytext/pipeline_anytext.py | 5 ++++- .../anytext/text_embedding_module.py | 1 + 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 35486eef8f80..efbd729a7207 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -36,10 +36,13 @@ def zero_module(module: nn.Module) -> nn.Module: class AuxiliaryLatentModule(nn.Module): - def __init__(self, dims, model_channels, glyph_channels, position_channels, font_path, **kwargs): + def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, model_channels=256, **kwargs): super().__init__() + if font_path is None: + raise ValueError("font_path must be provided!") self.font = ImageFont.truetype(font_path, 60) self.use_fp16 = kwargs.get("use_fp16", False) + self.device = kwargs.get("device", "cpu") self.glyph_block = nn.Sequential( conv_nd(dims, glyph_channels, 8, 3, padding=1), nn.SiLU(), @@ -80,6 +83,7 @@ def __init__(self, dims, model_channels, glyph_channels, position_channels, font self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) + @torch.no_grad() def forward( self, emb, @@ -349,3 +353,10 @@ def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, heigh img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) return img + + def to(self, device): + self.device = device + self.glyph_block = self.glyph_block.to(device) + self.position_block = self.position_block.to(device) + self.fuse_block = self.fuse_block.to(device) + return self diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 3de2af217994..d8bc42c6b1fb 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -21,6 +21,7 @@ import PIL.Image import torch import torch.nn.functional as F +from auxiliary_latent_module import AuxiliaryLatentModule from bert_tokenizer import BasicTokenizer from text_embedding_module import TextEmbeddingModule from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -226,9 +227,11 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, + font_path: str = None, ): super().__init__() - self.text_embedding_module = TextEmbeddingModule() + self.text_embedding_module = TextEmbeddingModule(text_encoder, tokenizer) + self.auxiliary_latent_module = AuxiliaryLatentModule(font_path) if safety_checker is None and requires_safety_checker: logger.warning( diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 837433752bec..f46f56908aed 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -28,6 +28,7 @@ def __init__(self, font_path): self.linear = nn.Linear() self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() + @torch.no_grad() def forward(self, texts, prompt, device, num_images_per_prompt, do_classifier_free_guidance): glyph_lines = self.create_glyph_lines(texts) ocr_output = self.ocr(glyph_lines) From a7d025f7500954a91c42f73737eb1a8eb81dfd85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 16:46:35 +0300 Subject: [PATCH 17/87] Remove several comments --- .../research_projects/anytext/auxiliary_latent_module.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index efbd729a7207..ab03d54c1ad5 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -14,7 +14,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Taken from AnyText.ldm.modules.diffusionmodules.util.conv_nd def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -102,25 +101,21 @@ def forward( ): if prompt is None and texts is None: raise ValueError("Prompt or texts must be provided!") - # return None, -1, "You have input Chinese prompt but the translator is not loaded!", "" n_lines = len(texts) if mode in ["text-generation", "gen"]: edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image elif mode in ["text-editing", "edit"]: if draw_pos is None or ori_image is None: raise ValueError("Reference image and position image are needed for text editing!") - # return None, -1, "Reference image and position image are needed for text editing!", "" if isinstance(ori_image, str): ori_image = cv2.imread(ori_image)[..., ::-1] if ori_image is None: raise ValueError(f"Can't read ori_image image from {ori_image}!") - # assert ori_image is not None, f"Can't read ori_image image from{ori_image}!" elif isinstance(ori_image, torch.Tensor): ori_image = ori_image.cpu().numpy() else: if not isinstance(ori_image, np.ndarray): raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - # assert isinstance(ori_image, np.ndarray), f"Unknown format of ori_image: {type(ori_image)}" edit_image = ori_image.clip(1, 255) # for mask reason edit_image = self.check_channels(edit_image) edit_image = self.resize_image( @@ -134,14 +129,12 @@ def forward( draw_pos = cv2.imread(draw_pos)[..., ::-1] if draw_pos is None: raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - # assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" pos_imgs = 255 - draw_pos elif isinstance(draw_pos, torch.Tensor): pos_imgs = draw_pos.cpu().numpy() else: if not isinstance(draw_pos, np.ndarray): raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - # assert isinstance(draw_pos, np.ndarray), f"Unknown format of draw_pos: {type(draw_pos)}" if mode in ["text-editing", "edit"]: pos_imgs = cv2.resize(pos_imgs, (w, h)) pos_imgs = pos_imgs[..., 0:1] From d2c5a65ed0d09c1cb9c28a989ae8fe3adf64fe35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 16:48:50 +0300 Subject: [PATCH 18/87] refactor: Remove ControlNetConditioningEmbedding and update code accordingly --- .../anytext/text_controlnet.py | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index 3587ec7358b8..8300f57b3f75 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -19,7 +19,6 @@ from diffusers.configuration_utils import register_to_config from diffusers.models.controlnet import ( - ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput, ) @@ -175,19 +174,14 @@ def __init__( global_pool_conditions, addition_embed_type_num_heads, ) - self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=3, - ) + self.controlnet_cond_embedding = None # This part is computed inside AuxiliaryLatentModel def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - controlnet_query_cond: torch.Tensor, + guided_hint: torch.Tensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, @@ -207,10 +201,8 @@ def forward( The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - controlnet_query_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + #controlnet_cond (`torch.Tensor`): + # The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): @@ -244,8 +236,8 @@ def forward( if channel_order == "rgb": # in rgb order by default ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + # elif channel_order == "bgr": + # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") @@ -318,9 +310,8 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond) - sample = sample + controlnet_cond + controlnet_query_cond + # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + guided_hint # 3. down down_block_res_samples = (sample,) @@ -352,7 +343,6 @@ def forward( sample = self.mid_block(sample, emb) # 5. Control net blocks - controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): From 2607b6b88c03bda907b9578607b659811f8247e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 17:16:09 +0300 Subject: [PATCH 19/87] Up --- examples/research_projects/anytext/text_controlnet.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index 8300f57b3f75..c8b085e6a04e 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# TODO: Figure out `hint_channels` enigma(?) from typing import Any, Dict, Optional, Tuple, Union import torch From a9fe4a057632a7413a516f10d8dd6833f8e5078c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 20:59:46 +0300 Subject: [PATCH 20/87] Up --- .../research_projects/anytext/auxiliary_latent_module.py | 6 +++--- examples/research_projects/anytext/text_controlnet.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index ab03d54c1ad5..f5b52f3bc8c1 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -102,9 +102,9 @@ def forward( if prompt is None and texts is None: raise ValueError("Prompt or texts must be provided!") n_lines = len(texts) - if mode in ["text-generation", "gen"]: + if mode == "generate": edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode in ["text-editing", "edit"]: + elif mode == "edit": if draw_pos is None or ori_image is None: raise ValueError("Reference image and position image are needed for text editing!") if isinstance(ori_image, str): @@ -135,7 +135,7 @@ def forward( else: if not isinstance(draw_pos, np.ndarray): raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - if mode in ["text-editing", "edit"]: + if mode == "edit": pos_imgs = cv2.resize(pos_imgs, (w, h)) pos_imgs = pos_imgs[..., 0:1] pos_imgs = cv2.convertScaleAbs(pos_imgs) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index c8b085e6a04e..c48c7081924c 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class TextControlNetModel(ControlNetModel): +class AnyTextControlNetModel(ControlNetModel): """ A PromptDiffusionControlNet model. From 567f553c2eef64f408b223290299fefd46cf6722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 21:11:49 +0300 Subject: [PATCH 21/87] up --- examples/research_projects/anytext/auxiliary_latent_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index f5b52f3bc8c1..879ec8d90851 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -102,7 +102,7 @@ def forward( if prompt is None and texts is None: raise ValueError("Prompt or texts must be provided!") n_lines = len(texts) - if mode == "generate": + if mode == "generate": edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image elif mode == "edit": if draw_pos is None or ori_image is None: From a9991d05152f1ef223c64d88c69e3dd70dd2222d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 21:12:59 +0300 Subject: [PATCH 22/87] refactor: Update AnyTextPipeline to include new optional parameters --- .../anytext/pipeline_anytext.py | 106 ++++++++++-------- 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index d8bc42c6b1fb..66bc491b37f4 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -940,10 +940,12 @@ def num_timesteps(self): def __call__( self, prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + mode: Optional[str] = "generate", + draw_pos: Optional[Union[str, torch.Tensor]] = None, + ori_image: Optional[Union[str, torch.Tensor]] = None, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: float = 7.5, @@ -1115,7 +1117,7 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - image, + # image, callback_steps, negative_prompt, prompt_embeds, @@ -1184,45 +1186,63 @@ def __call__( self.do_classifier_free_guidance, ) + # 3.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + # 4. Prepare image if isinstance(controlnet, ControlNetModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, + # image = self.prepare_image( + # image=image, + # width=width, + # height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + # device=device, + # dtype=controlnet.dtype, + # do_classifier_free_guidance=self.do_classifier_free_guidance, + # guess_mode=guess_mode, + # ) + # height, width = image.shape[-2:] + guided_hint = self.auxiliary_latent_module( + emb=timestep_cond, + context=prompt_embeds, + mode=mode, + texts=texts, + prompt=prompt, + draw_pos=draw_pos, + ori_image=ori_image, + img_count=len(prompt), ) - height, width = image.shape[-2:] - elif isinstance(controlnet, MultiControlNetModel): - images = [] - - # Nested lists as ControlNet condition - if isinstance(image[0], list): - # Transpose the nested image list - image = [list(t) for t in zip(*image)] - - for image_ in image: - image_ = self.prepare_image( - image=image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - images.append(image_) - - image = images - height, width = image[0].shape[-2:] + # elif isinstance(controlnet, MultiControlNetModel): + # images = [] + + # # Nested lists as ControlNet condition + # if isinstance(image[0], list): + # # Transpose the nested image list + # image = [list(t) for t in zip(*image)] + + # for image_ in image: + # image_ = self.prepare_image( + # image=image_, + # width=width, + # height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + # device=device, + # dtype=controlnet.dtype, + # do_classifier_free_guidance=self.do_classifier_free_guidance, + # guess_mode=guess_mode, + # ) + + # images.append(image_) + + # image = images + # height, width = image[0].shape[-2:] else: assert False @@ -1245,14 +1265,6 @@ def __call__( latents, ) - # 6.5 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -1309,7 +1321,7 @@ def __call__( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, + guided_hint=guided_hint, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, From 91252e03e1372976224913b1b73e186515bf63cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 1 Aug 2024 22:34:55 +0300 Subject: [PATCH 23/87] up --- .../anytext/bert_tokenizer.py | 1 + .../convert_original_anytext_to_diffusers.py | 2119 +---------------- 2 files changed, 2 insertions(+), 2118 deletions(-) diff --git a/examples/research_projects/anytext/bert_tokenizer.py b/examples/research_projects/anytext/bert_tokenizer.py index fbd3f37f3f1f..fd1e0ce32c47 100644 --- a/examples/research_projects/anytext/bert_tokenizer.py +++ b/examples/research_projects/anytext/bert_tokenizer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # TODO: Try to use the `transformers` library instead of this custom implementation if possible. """Tokenization classes.""" diff --git a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py index b74cdd783435..69a3d155ff33 100644 --- a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py +++ b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py @@ -1,2118 +1 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.""" - -import argparse -import re -from contextlib import nullcontext -from io import BytesIO -from typing import Dict, Optional, Union - -import requests -import torch -import yaml -from text_controlnet import TextControlNetModel -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from diffusers.models import ( - AutoencoderKL, - ControlNetModel, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available, logging - - -if is_accelerate_available(): - from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) - shape = old_checkpoint[path["old"]].shape - if is_attn_weight and len(shape) == 3: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - elif is_attn_weight and len(shape) == 4: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - if controlnet: - unet_params = original_config["model"]["params"]["control_stage_config"]["params"] - else: - if ( - "unet_config" in original_config["model"]["params"] - and original_config["model"]["params"]["unet_config"] is not None - ): - unet_params = original_config["model"]["params"]["unet_config"]["params"] - else: - unet_params = original_config["model"]["params"]["network_config"]["params"] - - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - - block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - if unet_params["transformer_depth"] is not None: - transformer_layers_per_block = ( - unet_params["transformer_depth"] - if isinstance(unet_params["transformer_depth"], int) - else list(unet_params["transformer_depth"]) - ) - else: - transformer_layers_per_block = 1 - - vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - - head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] - head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] - - class_embed_type = None - addition_embed_type = None - addition_time_embed_dim = None - projection_class_embeddings_input_dim = None - context_dim = None - - if unet_params["context_dim"] is not None: - context_dim = ( - unet_params["context_dim"] - if isinstance(unet_params["context_dim"], int) - else unet_params["context_dim"][0] - ) - - if "num_classes" in unet_params: - if unet_params["num_classes"] == "sequential": - if context_dim in [2048, 1280]: - # SDXL - addition_embed_type = "text_time" - addition_time_embed_dim = 256 - else: - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params["adm_in_channels"] - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params["in_channels"], - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params["num_res_blocks"], - "cross_attention_dim": context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "addition_embed_type": addition_embed_type, - "addition_time_embed_dim": addition_time_embed_dim, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "transformer_layers_per_block": transformer_layers_per_block, - } - - if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params["disable_self_attentions"] - - if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): - config["num_class_embeds"] = unet_params["num_classes"] - - if controlnet: - config["conditioning_channels"] = unet_params["hint_channels"] - else: - config["out_channels"] = unet_params["out_channels"] - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - - block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params["in_channels"], - "out_channels": vae_params["out_ch"], - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params["z_channels"], - "layers_per_block": vae_params["num_res_blocks"], - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config["model"]["params"]["timesteps"], - beta_start=original_config["model"]["params"]["linear_start"], - beta_end=original_config["model"]["params"]["linear_end"], - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint( - checkpoint, - config, - path=None, - extract_ema=False, - controlnet=False, - skip_extract_state_dict=False, - promptdiffusion=False, -): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - if skip_extract_state_dict: - unet_state_dict = checkpoint - else: - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") - logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... - elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - if config["addition_embed_type"] == "text_time": - new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - - # Relevant to StableDiffusionUpscalePipeline - if "num_class_embeds" in config: - if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): - new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet and not promptdiffusion: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - if promptdiffusion: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - new_checkpoint["controlnet_query_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_query_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.bias" - ) - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - new_checkpoint["controlnet_query_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_query_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_cond_block.{orig_index}.bias" - ) - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - keys = list(checkpoint.keys()) - vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): - if text_encoder is None: - config_name = "openai/clip-vit-large-patch14" - try: - config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModel(config) - else: - text_model = text_encoder - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] - - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("ln_final.weight", "text_model.final_layer_norm.weight"), - ("ln_final.bias", "text_model.final_layer_norm.bias"), - ("text_projection", "text_projection.weight"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint( - checkpoint, - config_name, - prefix="cond_stage_model.model.", - has_projection=False, - local_files_only=False, - **config_kwargs, -): - # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - # text_model = CLIPTextModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 - # ) - try: - config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - keys_to_ignore = [] - if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: - # make sure to remove all keys > 22 - keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] - keys_to_ignore += ["cond_stage_model.model.text_projection"] - - text_model_dict = {} - - if prefix + "text_projection" in checkpoint: - d_model = int(checkpoint[prefix + "text_projection"].shape[0]) - else: - d_model = 1024 - - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if key in keys_to_ignore: - continue - if key[len(prefix) :] in textenc_conversion_map: - if key.endswith("text_projection"): - value = checkpoint[key].T.contiguous() - else: - value = checkpoint[key] - - text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value - - if key.startswith(prefix + "transformer."): - new_key = key[len(prefix + "transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config, local_files_only=False): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. - """ - - image_embedder_config = original_config["model"]["params"]["embedder_config"] - - sd_clip_image_embedder_class = image_embedder_config["target"] - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only - ) - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. - """ - noise_aug_config = original_config["model"]["params"]["noise_aug_config"] - noise_aug_class = noise_aug_config["target"] - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=None, - cross_attention_dim=None, -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - if use_linear_projection is not None: - ctrlnet_config["use_linear_projection"] = use_linear_projection - - if cross_attention_dim is not None: - ctrlnet_config["cross_attention_dim"] = cross_attention_dim - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - controlnet = ControlNetModel(**ctrlnet_config) - - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - skip_extract_state_dict = True - else: - skip_extract_state_dict = False - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, - ctrlnet_config, - path=checkpoint_path, - extract_ema=extract_ema, - controlnet=True, - skip_extract_state_dict=skip_extract_state_dict, - ) - - if is_accelerate_available(): - for param_name, param in converted_ctrl_checkpoint.items(): - set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) - else: - controlnet.load_state_dict(converted_ctrl_checkpoint) - - return controlnet - - -def convert_promptdiffusion_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=None, - cross_attention_dim=None, -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - if use_linear_projection is not None: - ctrlnet_config["use_linear_projection"] = use_linear_projection - - if cross_attention_dim is not None: - ctrlnet_config["cross_attention_dim"] = cross_attention_dim - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - controlnet = TextControlNetModel(**ctrlnet_config) - - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - skip_extract_state_dict = True - else: - skip_extract_state_dict = False - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, - ctrlnet_config, - path=checkpoint_path, - extract_ema=extract_ema, - promptdiffusion=True, - controlnet=True, - skip_extract_state_dict=skip_extract_state_dict, - ) - - if is_accelerate_available(): - for param_name, param in converted_ctrl_checkpoint.items(): - set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) - else: - controlnet.load_state_dict(converted_ctrl_checkpoint) - - return controlnet - - -def download_from_original_stable_diffusion_ckpt( - checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], - original_config_file: str = None, - image_size: Optional[int] = None, - prediction_type: str = None, - model_type: str = None, - extract_ema: bool = False, - scheduler_type: str = "pndm", - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - stable_unclip: Optional[str] = None, - stable_unclip_prior: Optional[str] = None, - clip_stats_path: Optional[str] = None, - controlnet: Optional[bool] = None, - adapter: Optional[bool] = None, - load_safety_checker: bool = True, - pipeline_class: DiffusionPipeline = None, - local_files_only=False, - vae_path=None, - vae=None, - text_encoder=None, - text_encoder_2=None, - tokenizer=None, - tokenizer_2=None, - config_files=None, -) -> DiffusionPipeline: - """ - Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` - config file. - - Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the - global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is - recommended that you override the default values and/or supply an `original_config_file` wherever possible. - - Args: - checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically - inferred by looking for a key that only exists in SD2.0 models. - image_size (`int`, *optional*, defaults to 512): - The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 - Base. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable - Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to None): - The number of input channels. If `None`, it will be automatically inferred. - scheduler_type (`str`, *optional*, defaults to 'pndm'): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - model_type (`str`, *optional*, defaults to `None`): - The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", - "FrozenCLIPEmbedder", "PaintByExample"]`. - is_img2img (`bool`, *optional*, defaults to `False`): - Whether the model should be loaded as an img2img pipeline. - extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for - checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to - `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for - inference. Non-EMA weights are usually better to continue fine-tuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. This is necessary when running stable - diffusion 2.1. - device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. - from_safetensors (`str`, *optional*, defaults to `False`): - If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. Defaults to `True`. - pipeline_class (`str`, *optional*, defaults to `None`): - The pipeline class to use. Pass `None` to determine automatically. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - vae (`AutoencoderKL`, *optional*, defaults to `None`): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If - this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): - An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) - to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) - variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): - An instance of - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if - needed. - config_files (`Dict[str, str]`, *optional*, defaults to `None`): - A dictionary mapping from config file names to their contents. If this parameter is `None`, the function - will load the config files by itself, if needed. Valid keys are: - - `v1`: Config file for Stable Diffusion v1 - - `v2`: Config file for Stable Diffusion v2 - - `xl`: Config file for Stable Diffusion XL - - `xl_refiner`: Config file for Stable Diffusion XL Refiner - return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - """ - - # import pipelines here to avoid circular import error when using from_single_file method - from diffusers import ( - LDMTextToImagePipeline, - PaintByExamplePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionUpscalePipeline, - StableDiffusionXLControlNetInpaintPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - - if prediction_type == "v-prediction": - prediction_type = "v_prediction" - - if isinstance(checkpoint_path_or_dict, str): - if from_safetensors: - from safetensors.torch import load_file as safe_load - - checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) - else: - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) - elif isinstance(checkpoint_path_or_dict, dict): - checkpoint = checkpoint_path_or_dict - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - logger.debug("global_step key not found in model") - global_step = None - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - if original_config_file is None: - key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" - key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" - is_upscale = pipeline_class == StableDiffusionUpscalePipeline - - config_url = None - - # model_type = "v1" - if config_files is not None and "v1" in config_files: - original_config_file = config_files["v1"] - else: - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - - if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: - # model_type = "v2" - if config_files is not None and "v2" in config_files: - original_config_file = config_files["v2"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - elif key_name_sd_xl_base in checkpoint: - # only base xl has two text embedders - if config_files is not None and "xl" in config_files: - original_config_file = config_files["xl"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" - elif key_name_sd_xl_refiner in checkpoint: - # only refiner xl has embedder and one text embedders - if config_files is not None and "xl_refiner" in config_files: - original_config_file = config_files["xl_refiner"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - - if is_upscale: - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" - - if config_url is not None: - original_config_file = BytesIO(requests.get(config_url).content) - else: - with open(original_config_file, "r") as f: - original_config_file = f.read() - - original_config = yaml.safe_load(original_config_file) - - # Convert the text model. - if ( - model_type is None - and "cond_stage_config" in original_config["model"]["params"] - and original_config["model"]["params"]["cond_stage_config"] is not None - ): - model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - elif model_type is None and original_config["model"]["params"]["network_config"] is not None: - if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: - model_type = "SDXL" - else: - model_type = "SDXL-Refiner" - if image_size is None: - image_size = 1024 - - if pipeline_class is None: - # Check if we have a SDXL or SD model and initialize default pipeline - if model_type not in ["SDXL", "SDXL-Refiner"]: - pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline - else: - pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline - - if num_in_channels is None and pipeline_class in [ - StableDiffusionInpaintPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLControlNetInpaintPipeline, - ]: - num_in_channels = 9 - if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: - num_in_channels = 7 - elif num_in_channels is None: - num_in_channels = 4 - - if "unet_config" in original_config["model"]["params"]: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - if controlnet is None and "control_stage_config" in original_config["model"]["params"]: - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" - controlnet = convert_controlnet_checkpoint( - checkpoint, original_config, path, image_size, upcast_attention, extract_ema - ) - - if "timesteps" in original_config["model"]["params"]: - num_train_timesteps = original_config["model"]["params"]["timesteps"] - else: - num_train_timesteps = 1000 - - if model_type in ["SDXL", "SDXL-Refiner"]: - scheduler_dict = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": num_train_timesteps, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", - } - scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) - scheduler_type = "euler" - else: - if "linear_start" in original_config["model"]["params"]: - beta_start = original_config["model"]["params"]["linear_start"] - else: - beta_start = 0.02 - - if "linear_end" in original_config["model"]["params"]: - beta_end = original_config["model"]["params"]["linear_end"] - else: - beta_end = 0.085 - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - if pipeline_class == StableDiffusionUpscalePipeline: - image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = upcast_attention - - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=path, extract_ema=extract_ema - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - unet = UNet2DConditionModel(**unet_config) - - if is_accelerate_available(): - if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - else: - unet.load_state_dict(converted_unet_checkpoint) - - # Convert the VAE model. - if vae_path is None and vae is None: - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if ( - "model" in original_config - and "params" in original_config["model"] - and "scale_factor" in original_config["model"]["params"] - ): - vae_scaling_factor = original_config["model"]["params"]["scale_factor"] - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - for param_name, param in converted_vae_checkpoint.items(): - set_module_tensor_to_device(vae, param_name, "cpu", value=param) - else: - vae.load_state_dict(converted_vae_checkpoint) - elif vae is None: - vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) - - if model_type == "FrozenOpenCLIPEmbedder": - config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} - - if text_encoder is None: - text_model = convert_open_clip_checkpoint( - checkpoint, config_name, local_files_only=local_files_only, **config_kwargs - ) - else: - text_model = text_encoder - - try: - tokenizer = CLIPTokenizer.from_pretrained( - "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." - ) - - if stable_unclip is None: - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - controlnet=controlnet, - safety_checker=None, - feature_extractor=None, - ) - if hasattr(pipe, "requires_safety_checker"): - pipe.requires_safety_checker = False - - elif pipeline_class == StableDiffusionUpscalePipeline: - scheduler = DDIMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" - ) - low_res_scheduler = DDPMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" - ) - - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - low_res_scheduler=low_res_scheduler, - safety_checker=None, - feature_extractor=None, - ) - - else: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - ) - if hasattr(pipe, "requires_safety_checker"): - pipe.requires_safety_checker = False - - else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) - - if stable_unclip == "img2img": - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - pipe = StableUnCLIPImg2ImgPipeline( - # image encoding components - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - elif stable_unclip == "txt2img": - if stable_unclip_prior is None or stable_unclip_prior == "karlo": - karlo_model = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained( - karlo_model, subfolder="prior", local_files_only=local_files_only - ) - - try: - prior_tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - prior_text_model = CLIPTextModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - - prior_scheduler = UnCLIPScheduler.from_pretrained( - karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only - ) - prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) - else: - raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") - - pipe = StableUnCLIPPipeline( - # prior components - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_model, - prior=prior, - prior_scheduler=prior_scheduler, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - else: - raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") - elif model_type == "PaintByExample": - vision_model = convert_paint_by_example_checkpoint(checkpoint) - try: - tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - try: - feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." - ) - pipe = PaintByExamplePipeline( - vae=vae, - image_encoder=vision_model, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=feature_extractor, - ) - elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint( - checkpoint, local_files_only=local_files_only, text_encoder=text_encoder - ) - try: - tokenizer = ( - CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) - if tokenizer is None - else tokenizer - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - - if load_safety_checker: - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - else: - safety_checker = None - feature_extractor = None - - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - elif model_type in ["SDXL", "SDXL-Refiner"]: - is_refiner = model_type == "SDXL-Refiner" - - if (is_refiner is False) and (tokenizer is None): - try: - tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - - if (is_refiner is False) and (text_encoder is None): - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - - if tokenizer_2 is None: - try: - tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." - ) - - if text_encoder_2 is None: - config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config_kwargs = {"projection_dim": 1280} - prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." - - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, - config_name, - prefix=prefix, - has_projection=True, - local_files_only=local_files_only, - **config_kwargs, - ) - - if is_accelerate_available(): # SBM Now move model to cpu. - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - elif adapter: - pipe = pipeline_class( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - adapter=adapter, - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - - else: - pipeline_kwargs = { - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "unet": unet, - "scheduler": scheduler, - } - - if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( - pipeline_class == StableDiffusionXLInpaintPipeline - ): - pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) - - if is_refiner: - pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) - - pipe = pipeline_class(**pipeline_kwargs) - else: - text_config = create_ldm_bert_config(original_config) - text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - return pipe - - -def download_controlnet_from_original_ckpt( - checkpoint_path: str, - original_config_file: str, - image_size: int = 512, - extract_ema: bool = False, - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - use_linear_projection: Optional[bool] = None, - cross_attention_dim: Optional[bool] = None, -) -> DiffusionPipeline: - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - original_config = yaml.safe_load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if "control_stage_config" not in original_config["model"]["params"]: - raise ValueError("`control_stage_config` not present in original config") - - controlnet = convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=use_linear_projection, - cross_attention_dim=cross_attention_dim, - ) - - return controlnet - - -def download_promptdiffusion_from_original_ckpt( - checkpoint_path: str, - original_config_file: str, - image_size: int = 512, - extract_ema: bool = False, - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - use_linear_projection: Optional[bool] = None, - cross_attention_dim: Optional[bool] = None, -) -> DiffusionPipeline: - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - original_config = yaml.safe_load(open(original_config_file)) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - if "control_stage_config" not in original_config["model"]["params"]: - raise ValueError("`control_stage_config` not present in original config") - - controlnet = convert_promptdiffusion_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=use_linear_projection, - cross_attention_dim=cross_attention_dim, - ) - - return controlnet - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." - ) - parser.add_argument( - "--original_config_file", - type=str, - required=True, - help="The YAML config file corresponding to the original architecture.", - ) - parser.add_argument( - "--num_in_channels", - default=None, - type=int, - help="The number of input channels. If `None` number of input channels will be automatically inferred.", - ) - parser.add_argument( - "--image_size", - default=512, - type=int, - help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" - " Base. Use 768 for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--extract_ema", - action="store_true", - help=( - "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" - " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" - " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." - ), - ) - parser.add_argument( - "--upcast_attention", - action="store_true", - help=( - "Whether the attention computation should always be upcasted. This is necessary when running stable" - " diffusion 2.1." - ), - ) - parser.add_argument( - "--from_safetensors", - action="store_true", - help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", - ) - parser.add_argument( - "--to_safetensors", - action="store_true", - help="Whether to store pipeline in safetensors format or not.", - ) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - - # small workaround to get argparser to parse a boolean input as either true _or_ false - def parse_bool(string): - if string == "True": - return True - elif string == "False": - return False - else: - raise ValueError(f"could not parse string as bool {string}") - - parser.add_argument( - "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool - ) - - parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) - - args = parser.parse_args() - - controlnet = download_promptdiffusion_from_original_ckpt( - checkpoint_path=args.checkpoint_path, - original_config_file=args.original_config_file, - image_size=args.image_size, - extract_ema=args.extract_ema, - num_in_channels=args.num_in_channels, - upcast_attention=args.upcast_attention, - from_safetensors=args.from_safetensors, - device=args.device, - use_linear_projection=args.use_linear_projection, - cross_attention_dim=args.cross_attention_dim, - ) - - controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) +# In construction... \ No newline at end of file From b9164e3e9e8433029b0fb2458ca71c13f65a02c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 2 Aug 2024 18:57:31 +0300 Subject: [PATCH 24/87] feat: Add OCR model and its components --- .../anytext/ocr_recog/RNN.py | 210 +++++++ .../anytext/ocr_recog/RecCTCHead.py | 48 ++ .../anytext/ocr_recog/RecModel.py | 45 ++ .../anytext/ocr_recog/RecMv1_enhance.py | 233 +++++++ .../anytext/ocr_recog/RecSVTR.py | 591 ++++++++++++++++++ .../anytext/ocr_recog/common.py | 74 +++ .../anytext/ocr_recog/en_dict.txt | 95 +++ .../research_projects/anytext/recognizer.py | 315 ++++++++++ 8 files changed, 1611 insertions(+) create mode 100755 examples/research_projects/anytext/ocr_recog/RNN.py create mode 100755 examples/research_projects/anytext/ocr_recog/RecCTCHead.py create mode 100755 examples/research_projects/anytext/ocr_recog/RecModel.py create mode 100644 examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py create mode 100644 examples/research_projects/anytext/ocr_recog/RecSVTR.py create mode 100644 examples/research_projects/anytext/ocr_recog/common.py create mode 100644 examples/research_projects/anytext/ocr_recog/en_dict.txt create mode 100755 examples/research_projects/anytext/recognizer.py diff --git a/examples/research_projects/anytext/ocr_recog/RNN.py b/examples/research_projects/anytext/ocr_recog/RNN.py new file mode 100755 index 000000000000..cf16855b3711 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RNN.py @@ -0,0 +1,210 @@ +from torch import nn +import torch +from .RecSVTR import Block + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self,x): + return x*torch.sigmoid(x) + +class Im2Im(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + return x + +class Im2Seq(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + B, C, H, W = x.shape + # assert H == 1 + x = x.reshape(B, C, H * W) + x = x.permute((0, 2, 1)) + return x + +class EncoderWithRNN(nn.Module): + def __init__(self, in_channels,**kwargs): + super(EncoderWithRNN, self).__init__() + hidden_size = kwargs.get('hidden_size', 256) + self.out_channels = hidden_size * 2 + self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True) + + def forward(self, x): + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + return x + +class SequenceEncoder(nn.Module): + def __init__(self, in_channels, encoder_type='rnn', **kwargs): + super(SequenceEncoder, self).__init__() + self.encoder_reshape = Im2Seq(in_channels) + self.out_channels = self.encoder_reshape.out_channels + self.encoder_type = encoder_type + if encoder_type == 'reshape': + self.only_reshape = True + else: + support_encoder_dict = { + 'reshape': Im2Seq, + 'rnn': EncoderWithRNN, + 'svtr': EncoderWithSVTR + } + assert encoder_type in support_encoder_dict, '{} must in {}'.format( + encoder_type, support_encoder_dict.keys()) + + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels,**kwargs) + self.out_channels = self.encoder.out_channels + self.only_reshape = False + + def forward(self, x): + if self.encoder_type != 'svtr': + x = self.encoder_reshape(x) + if not self.only_reshape: + x = self.encoder(x) + return x + else: + x = self.encoder(x) + x = self.encoder_reshape(x) + return x + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr) + self.norm = nn.BatchNorm2d(out_channels) + self.act = Swish() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class EncoderWithSVTR(nn.Module): + def __init__( + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0., + qk_scale=None): + super(EncoderWithSVTR, self).__init__() + self.depth = depth + self.use_guide = use_guide + self.conv1 = ConvBNLayer( + in_channels, in_channels // 8, padding=1, act='swish') + self.conv2 = ConvBNLayer( + in_channels // 8, hidden_dims, kernel_size=1, act='swish') + + self.svtr_block = nn.ModuleList([ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer='Global', + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer='swish', + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer='nn.LayerNorm', + epsilon=1e-05, + prenorm=False) for i in range(depth) + ]) + self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) + self.conv3 = ConvBNLayer( + hidden_dims, in_channels, kernel_size=1, act='swish') + # last conv-nxn, the input is concat of input tensor and conv3 output tensor + self.conv4 = ConvBNLayer( + 2 * in_channels, in_channels // 8, padding=1, act='swish') + + self.conv1x1 = ConvBNLayer( + in_channels // 8, dims, kernel_size=1, act='swish') + self.out_channels = dims + self.apply(self._init_weights) + + def _init_weights(self, m): + # weight initialization + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + # for use guide + if self.use_guide: + z = x.clone() + z.stop_gradient = True + else: + z = x + # for short cut + h = z + # reduce dim + z = self.conv1(z) + z = self.conv2(z) + # SVTR global block + B, C, H, W = z.shape + z = z.flatten(2).permute(0, 2, 1) + + for blk in self.svtr_block: + z = blk(z) + + z = self.norm(z) + # last stage + z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2) + z = self.conv3(z) + z = torch.cat((h, z), dim=1) + z = self.conv1x1(self.conv4(z)) + + return z + +if __name__=="__main__": + svtrRNN = EncoderWithSVTR(56) + print(svtrRNN) \ No newline at end of file diff --git a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py new file mode 100755 index 000000000000..867ede9916b1 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py @@ -0,0 +1,48 @@ +from torch import nn + + +class CTCHead(nn.Module): + def __init__(self, + in_channels, + out_channels=6625, + fc_decay=0.0004, + mid_channels=None, + return_feats=False, + **kwargs): + super(CTCHead, self).__init__() + if mid_channels is None: + self.fc = nn.Linear( + in_channels, + out_channels, + bias=True,) + else: + self.fc1 = nn.Linear( + in_channels, + mid_channels, + bias=True, + ) + self.fc2 = nn.Linear( + mid_channels, + out_channels, + bias=True, + ) + + self.out_channels = out_channels + self.mid_channels = mid_channels + self.return_feats = return_feats + + def forward(self, x, labels=None): + if self.mid_channels is None: + predicts = self.fc(x) + else: + x = self.fc1(x) + predicts = self.fc2(x) + + if self.return_feats: + result = dict() + result['ctc'] = predicts + result['ctc_neck'] = x + else: + result = predicts + + return result diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py new file mode 100755 index 000000000000..c2313bf02c95 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -0,0 +1,45 @@ +from torch import nn +from .RNN import SequenceEncoder, Im2Seq, Im2Im +from .RecMv1_enhance import MobileNetV1Enhance + +from .RecCTCHead import CTCHead + +backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance} +neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} +head_dict = {'CTCHead':CTCHead} + + +class RecModel(nn.Module): + def __init__(self, config): + super().__init__() + assert 'in_channels' in config, 'in_channels must in model config' + backbone_type = config.backbone.pop('type') + assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}' + self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) + + neck_type = config.neck.pop('type') + assert neck_type in neck_dict, f'neck.type must in {neck_dict}' + self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) + + head_type = config.head.pop('type') + assert head_type in head_dict, f'head.type must in {head_dict}' + self.head = head_dict[head_type](self.neck.out_channels, **config.head) + + self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' + + def load_3rd_state_dict(self, _3rd_name, _state): + self.backbone.load_3rd_state_dict(_3rd_name, _state) + self.neck.load_3rd_state_dict(_3rd_name, _state) + self.head.load_3rd_state_dict(_3rd_name, _state) + + def forward(self, x): + x = self.backbone(x) + x = self.neck(x) + x = self.head(x) + return x + + def encode(self, x): + x = self.backbone(x) + x = self.neck(x) + x = self.head.ctc_encoder(x) + return x diff --git a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py new file mode 100644 index 000000000000..d5c848533dd3 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py @@ -0,0 +1,233 @@ +import os, sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from .common import Activation + + +class ConvBNLayer(nn.Module): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='hard_swish'): + super(ConvBNLayer, self).__init__() + self.act = act + self._conv = nn.Conv2d( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + bias=False) + + self._batch_norm = nn.BatchNorm2d( + num_filters, + ) + if self.act is not None: + self._act = Activation(act_type=act, inplace=True) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + if self.act is not None: + y = self._act(y) + return y + + +class DepthwiseSeparable(nn.Module): + def __init__(self, + num_channels, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + dw_size=3, + padding=1, + use_se=False): + super(DepthwiseSeparable, self).__init__() + self.use_se = use_se + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=dw_size, + stride=stride, + padding=padding, + num_groups=int(num_groups * scale)) + if use_se: + self._se = SEModule(int(num_filters1 * scale)) + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + if self.use_se: + y = self._se(y) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1Enhance(nn.Module): + def __init__(self, + in_channels=3, + scale=0.5, + last_conv_stride=1, + last_pool_type='max', + **kwargs): + super().__init__() + self.scale = scale + self.block_list = [] + + self.conv1 = ConvBNLayer( + num_channels=in_channels, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1) + + conv2_1 = DepthwiseSeparable( + num_channels=int(32 * scale), + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale) + self.block_list.append(conv2_1) + + conv2_2 = DepthwiseSeparable( + num_channels=int(64 * scale), + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=1, + scale=scale) + self.block_list.append(conv2_2) + + conv3_1 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale) + self.block_list.append(conv3_1) + + conv3_2 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=(2, 1), + scale=scale) + self.block_list.append(conv3_2) + + conv4_1 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale) + self.block_list.append(conv4_1) + + conv4_2 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=(2, 1), + scale=scale) + self.block_list.append(conv4_2) + + for _ in range(5): + conv5 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + dw_size=5, + padding=2, + scale=scale, + use_se=False) + self.block_list.append(conv5) + + conv5_6 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=(2, 1), + dw_size=5, + padding=2, + scale=scale, + use_se=True) + self.block_list.append(conv5_6) + + conv6 = DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=last_conv_stride, + dw_size=5, + padding=2, + use_se=True, + scale=scale) + self.block_list.append(conv6) + + self.block_list = nn.Sequential(*self.block_list) + if last_pool_type == 'avg': + self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.out_channels = int(1024 * scale) + + def forward(self, inputs): + y = self.conv1(inputs) + y = self.block_list(y) + y = self.pool(y) + return y + +def hardsigmoid(x): + return F.relu6(x + 3., inplace=True) / 6. + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.conv2 = nn.Conv2d( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hardsigmoid(outputs) + x = torch.mul(inputs, outputs) + + return x diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py new file mode 100644 index 000000000000..484b3df99125 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py @@ -0,0 +1,591 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.nn.init import trunc_normal_, zeros_, ones_ +from torch.nn import functional + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = torch.tensor(1 - drop_prob) + shape = (x.size()[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype) + random_tensor = torch.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self,x): + return x*torch.sigmoid(x) + + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr) + self.norm = nn.BatchNorm2d(out_channels) + self.act = act() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + if isinstance(act_layer, str): + self.act = Swish() + else: + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMixer(nn.Module): + def __init__( + self, + dim, + num_heads=8, + HW=(8, 25), + local_k=(3, 3), ): + super().__init__() + self.HW = HW + self.dim = dim + self.local_mixer = nn.Conv2d( + dim, + dim, + local_k, + 1, (local_k[0] // 2, local_k[1] // 2), + groups=num_heads, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + + def forward(self, x): + h = self.HW[0] + w = self.HW[1] + x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + mixer='Global', + HW=(8, 25), + local_k=(7, 11), + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H = HW[0] + W = HW[1] + self.N = H * W + self.C = dim + if mixer == 'Local' and HW is not None: + hk = local_k[0] + wk = local_k[1] + mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) + for h in range(0, H): + for w in range(0, W): + mask[h * W + w, h:h + hk, w:w + wk] = 0. + mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // + 2].flatten(1) + mask_inf = torch.full([H * W, H * W],fill_value=float('-inf')) + mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) + self.mask = mask[None,None,:] + # self.mask = mask.unsqueeze([0, 1]) + self.mixer = mixer + + def forward(self, x): + if self.HW is not None: + N = self.N + C = self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = (q.matmul(k.permute((0, 1, 3, 2)))) + if self.mixer == 'Local': + attn += self.mask + attn = functional.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, + dim, + num_heads, + mixer='Global', + local_mixer=(7, 11), + HW=(8, 25), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-6, + prenorm=True): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm1 = norm_layer(dim) + if mixer == 'Global' or mixer == 'Local': + + self.mixer = Attention( + dim, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=local_mixer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + elif mixer == 'Conv': + self.mixer = ConvMixer( + dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + else: + raise TypeError("The mixer must be one of [Global, Local, Conv]") + + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.prenorm = prenorm + + def forward(self, x): + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=(32, 100), + in_channels=3, + embed_dim=768, + sub_num=2): + super().__init__() + num_patches = (img_size[1] // (2 ** sub_num)) * \ + (img_size[0] // (2 ** sub_num)) + self.img_size = img_size + self.num_patches = num_patches + self.embed_dim = embed_dim + self.norm = None + if sub_num == 2: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False)) + if sub_num == 3: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 4, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 4, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False)) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).permute(0, 2, 1) + return x + + +class SubSample(nn.Module): + def __init__(self, + in_channels, + out_channels, + types='Pool', + stride=(2, 1), + sub_norm='nn.LayerNorm', + act=None): + super().__init__() + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2d( + kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.maxpool = nn.MaxPool2d( + kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + self.norm = eval(sub_norm)(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x): + + if self.types == 'Pool': + x1 = self.avgpool(x) + x2 = self.maxpool(x) + x = (x1 + x2) * 0.5 + out = self.proj(x.flatten(2).permute((0, 2, 1))) + else: + x = self.conv(x) + out = x.flatten(2).permute((0, 2, 1)) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +class SVTRNet(nn.Module): + def __init__( + self, + img_size=[48, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=['Local'] * 6 + ['Global'] * + 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging='Conv', # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + last_drop=0.1, + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer='nn.LayerNorm', + sub_norm='nn.LayerNorm', + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit='Block', + act='nn.GELU', + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs): + super().__init__() + self.img_size = img_size + self.embed_dim = embed_dim + self.out_channels = out_channels + self.prenorm = prenorm + patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging + self.patch_embed = PatchEmbed( + img_size=img_size, + in_channels=in_channels, + embed_dim=embed_dim[0], + sub_num=sub_num) + num_patches = self.patch_embed.num_patches + self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) + # self.pos_embed = self.create_parameter( + # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) + + # self.add_parameter("pos_embed", self.pos_embed) + + self.pos_drop = nn.Dropout(p=drop_rate) + Block_unit = eval(block_unit) + + dpr = np.linspace(0, drop_path_rate, sum(depth)) + self.blocks1 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0:depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[0]) + ] + ) + if patch_merging is not None: + self.sub_sample1 = SubSample( + embed_dim[0], + embed_dim[1], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 2, self.HW[1]] + else: + HW = self.HW + self.patch_merging = patch_merging + self.blocks2 = nn.ModuleList([ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + mixer=mixer[depth[0]:depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[1]) + ]) + if patch_merging is not None: + self.sub_sample2 = SubSample( + embed_dim[1], + embed_dim[2], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 4, self.HW[1]] + else: + HW = self.HW + self.blocks3 = nn.ModuleList([ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1]:][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[2]) + ]) + self.last_stage = last_stage + if last_stage: + self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num)) + self.last_conv = nn.Conv2d( + in_channels=embed_dim[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.hardswish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop) + if not prenorm: + self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) + self.use_lenhead = use_lenhead + if use_lenhead: + self.len_conv = nn.Linear(embed_dim[2], self.out_channels) + self.hardswish_len = nn.Hardswish() + self.dropout_len = nn.Dropout( + p=last_drop) + + trunc_normal_(self.pos_embed,std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight,std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample1( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[0], self.HW[0], self.HW[1]])) + for blk in self.blocks2: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample2( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.use_lenhead: + len_x = self.len_conv(x.mean(1)) + len_x = self.dropout_len(self.hardswish_len(len_x)) + if self.last_stage: + if self.patch_merging is not None: + h = self.HW[0] // 4 + else: + h = self.HW[0] + x = self.avg_pool( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[2], h, self.HW[1]])) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + if self.use_lenhead: + return x, len_x + return x + + +if __name__=="__main__": + a = torch.rand(1,3,48,100) + svtr = SVTRNet() + + out = svtr(a) + print(svtr) + print(out.size()) \ No newline at end of file diff --git a/examples/research_projects/anytext/ocr_recog/common.py b/examples/research_projects/anytext/ocr_recog/common.py new file mode 100644 index 000000000000..a328bb034a37 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/common.py @@ -0,0 +1,74 @@ + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3., inplace=self.inplace) / 6. + +# out = max(0, min(1, slop*x+offset)) +# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + # torch: F.relu6(x + 3., inplace=self.inplace) / 6. + # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + +class GELU(nn.Module): + def __init__(self, inplace=True): + super(GELU, self).__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.gelu(x) + + +class Swish(nn.Module): + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(torch.sigmoid(x)) + return x + else: + return x*torch.sigmoid(x) + + +class Activation(nn.Module): + def __init__(self, act_type, inplace=True): + super(Activation, self).__init__() + act_type = act_type.lower() + if act_type == 'relu': + self.act = nn.ReLU(inplace=inplace) + elif act_type == 'relu6': + self.act = nn.ReLU6(inplace=inplace) + elif act_type == 'sigmoid': + raise NotImplementedError + elif act_type == 'hard_sigmoid': + self.act = Hsigmoid(inplace) + elif act_type == 'hard_swish': + self.act = Hswish(inplace=inplace) + elif act_type == 'leakyrelu': + self.act = nn.LeakyReLU(inplace=inplace) + elif act_type == 'gelu': + self.act = GELU(inplace=inplace) + elif act_type == 'swish': + self.act = Swish(inplace=inplace) + else: + raise NotImplementedError + + def forward(self, inputs): + return self.act(inputs) \ No newline at end of file diff --git a/examples/research_projects/anytext/ocr_recog/en_dict.txt b/examples/research_projects/anytext/ocr_recog/en_dict.txt new file mode 100644 index 000000000000..7677d31b9d3f --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/en_dict.txt @@ -0,0 +1,95 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ + diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py new file mode 100755 index 000000000000..a9fa3880906a --- /dev/null +++ b/examples/research_projects/anytext/recognizer.py @@ -0,0 +1,315 @@ +""" +Copyright (c) Alibaba, Inc. and its affiliates. +""" +import math +import os +import sys +import time +import traceback + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict +from ocr_recog.RecModel import RecModel +from skimage.transform._geometric import _umeyama as get_sym_mat + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def min_bounding_rect(img): + ret, thresh = cv2.threshold(img, 127, 255, 0) + contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + print("Bad contours, using fake bbox...") + return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) + max_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(max_contour) + box = cv2.boxPoints(rect) + box = np.int0(box) + # sort + x_sorted = sorted(box, key=lambda x: x[0]) + left = x_sorted[:2] + right = x_sorted[2:] + left = sorted(left, key=lambda x: x[1]) + (tl, bl) = left + right = sorted(right, key=lambda x: x[1]) + (tr, br) = right + if tl[1] > bl[1]: + (tl, bl) = (bl, tl) + if tr[1] > br[1]: + (tr, br) = (br, tr) + return np.array([tl, tr, br, bl]) + + +def adjust_image(box, img): + pts1 = np.float32([box[0], box[1], box[2], box[3]]) + width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) + height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) + pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) + # get transform matrix + M = get_sym_mat(pts1, pts2, estimate_scale=True) + C, H, W = img.shape + T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) + theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) + theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) + grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) + result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) + result = torch.clamp(result.squeeze(0), 0, 255) + # crop + result = result[:, : int(height), : int(width)] + return result + + +""" +mask: numpy.ndarray, mask of textual, HWC +src_img: torch.Tensor, source image, CHW +""" + + +def crop_image(src_img, mask): + box = min_bounding_rect(mask) + result = adjust_image(box, src_img) + if len(result.shape) == 2: + result = torch.stack([result] * 3, axis=-1) + return result + + +def create_predictor(model_dir=None, model_lang="ch", is_onnx=False): + model_file_path = model_dir + if model_file_path is not None and not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format(model_file_path)) + + if is_onnx: + import onnxruntime as ort + + sess = ort.InferenceSession( + model_file_path, providers=["CPUExecutionProvider"] + ) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' + return sess + else: + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 + else: + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = edict( + in_channels=3, + backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), + neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), + head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + ) + + rec_model = RecModel(rec_config) + if model_file_path is not None: + rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu")) + rec_model.eval() + return rec_model.eval() + + +def _check_image_file(path): + img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") + return path.lower().endswith(tuple(img_end)) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +class TextRecognizer(object): + def __init__(self, args, predictor): + self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.rec_batch_num = args.rec_batch_num + self.predictor = predictor + self.chars = self.get_char_dict(args.rec_char_dict_path) + self.char2id = {x: i for i, x in enumerate(self.chars)} + self.is_onnx = not isinstance(self.predictor, torch.nn.Module) + self.use_fp16 = args.use_fp16 + + # img: CHW + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[0] + imgW = int((imgH * max_wh_ratio)) + + h, w = img.shape[1:] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = torch.nn.functional.interpolate( + img.unsqueeze(0), + size=(imgH, resized_w), + mode="bilinear", + align_corners=True, + ) + resized_image /= 255.0 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) + padding_im[:, :, 0:resized_w] = resized_image[0] + return padding_im + + # img_list: list of tensors with shape chw 0-255 + def pred_imglist(self, img_list, show_debug=False): + img_num = len(img_list) + assert img_num > 0 + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[2] / float(img.shape[1])) + # Sorting can speed up the recognition process + indices = torch.from_numpy(np.argsort(np.array(width_list))) + batch_num = self.rec_batch_num + preds_all = [None] * img_num + preds_neck_all = [None] * img_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[1:] + if h > w * 1.2: + img = img_list[indices[ino]] + img = torch.transpose(img, 1, 2).flip(dims=[1]) + img_list[indices[ino]] = img + h, w = img.shape[1:] + # wh_ratio = w * 1.0 / h + # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + if self.use_fp16: + norm_img = norm_img.half() + norm_img = norm_img.unsqueeze(0) + norm_img_batch.append(norm_img) + norm_img_batch = torch.cat(norm_img_batch, dim=0) + if show_debug: + for i in range(len(norm_img_batch)): + _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() + _img = (_img + 0.5) * 255 + _img = _img[:, :, ::-1] + file_name = f"{indices[beg_img_no + i]}" + if os.path.exists(file_name + ".jpg"): + file_name += "_2" # ori image + cv2.imwrite(file_name + ".jpg", _img) + if self.is_onnx: + input_dict = {} + input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() + outputs = self.predictor.run(None, input_dict) + preds = {} + preds["ctc"] = torch.from_numpy(outputs[0]) + preds["ctc_neck"] = [torch.zeros(1)] * img_num + else: + preds = self.predictor(norm_img_batch) + for rno in range(preds["ctc"].shape[0]): + preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] + preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] + + return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) + + def get_char_dict(self, character_dict_path): + character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str.append(line) + dict_character = list(character_str) + dict_character = ["sos"] + dict_character + [" "] # eos is space + return dict_character + + def get_text(self, order): + char_list = [self.chars[text_id] for text_id in order] + return "".join(char_list) + + def decode(self, mat): + text_index = mat.detach().cpu().numpy().argmax(axis=1) + ignored_tokens = [0] + selection = np.ones(len(text_index), dtype=bool) + selection[1:] = text_index[1:] != text_index[:-1] + for ignored_token in ignored_tokens: + selection &= text_index != ignored_token + return text_index[selection], np.where(selection)[0] + + def get_ctcloss(self, preds, gt_text, weight): + if not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight).to(preds.device) + ctc_loss = torch.nn.CTCLoss(reduction="none") + log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC + targets = [] + target_lengths = [] + for t in gt_text: + targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] + target_lengths += [len(t)] + targets = torch.tensor(targets).to(preds.device) + target_lengths = torch.tensor(target_lengths).to(preds.device) + input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) + loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) + loss = loss / input_lengths * weight + return loss + + +def main(): + rec_model_dir = "./ocr_weights/ppv3_rec.pth" + predictor = create_predictor(rec_model_dir) + args = edict() + args.rec_image_shape = "3, 48, 320" + args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt" + args.rec_batch_num = 6 + text_recognizer = TextRecognizer(args, predictor) + image_dir = "./test_imgs_cn" + gt_text = ["韩国小馆"] * 14 + + image_file_list = get_image_file_list(image_dir) + valid_image_file_list = [] + img_list = [] + + for image_file in image_file_list: + img = cv2.imread(image_file) + if img is None: + print("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(torch.from_numpy(img).permute(2, 0, 1).float()) + try: + tic = time.time() + times = [] + for i in range(10): + preds, _ = text_recognizer.pred_imglist(img_list) # get text + preds_all = preds.softmax(dim=2) + times += [(time.time() - tic) * 1000.0] + tic = time.time() + print(times) + print(np.mean(times[1:]) / len(preds_all)) + weight = np.ones(len(gt_text)) + loss = text_recognizer.get_ctcloss(preds, gt_text, weight) + for i in range(len(valid_image_file_list)): + pred = preds_all[i] + order, idx = text_recognizer.decode(pred) + text = text_recognizer.get_text(order) + print(f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}') + except Exception as E: + print(traceback.format_exc(), E) + + +if __name__ == "__main__": + main() From cd4c9c2246738bf20f07eb6b0737aee0343a0928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 2 Aug 2024 19:00:13 +0300 Subject: [PATCH 25/87] chore: Update `TextEmbeddingModule` to include OCR model components and dependencies --- .../anytext/embedding_manager.py | 196 ++++++++++++++++++ .../anytext/text_embedding_module.py | 31 ++- 2 files changed, 225 insertions(+), 2 deletions(-) create mode 100644 examples/research_projects/anytext/embedding_manager.py diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py new file mode 100644 index 000000000000..cbaab5aab682 --- /dev/null +++ b/examples/research_projects/anytext/embedding_manager.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Alibaba, Inc. and its affiliates. +""" +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# Copied from diffusers.models.controlnet.zero_module +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"] + assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + assert ( + torch.count_nonzero(token) == 3 + ), f"String '{string}' maps to more than a single token. Please use another string" + token = token[0, 1] + return token + + +def get_clip_vision_emb(encoder, processor, img): + _img = img.repeat(1, 3, 1, 1) * 255 + inputs = processor(images=_img, return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(img.device) + outputs = encoder(**inputs) + emb = outputs.image_embeds + return emb + + +def get_recog_emb(encoder, img_list): + _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] + encoder.predictor.eval() + _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) + return preds_neck + + +def pad_H(x): + _, _, H, W = x.shape + p_top = (W - H) // 2 + p_bot = W - H - p_top + return F.pad(x, (0, 0, p_top, p_bot)) + + +class EncodeNet(nn.Module): + def __init__(self, in_channels, out_channels): + super(EncodeNet, self).__init__() + chan = 16 + n_layer = 4 # downsample + + self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1) + self.conv_list = nn.ModuleList([]) + _c = chan + for i in range(n_layer): + self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2)) + _c *= 2 + self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.act = nn.SiLU() + + def forward(self, x): + x = self.act(self.conv1(x)) + for layer in self.conv_list: + x = self.act(layer(x)) + x = self.act(self.conv2(x)) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + valid=True, + glyph_channels=20, + position_channels=1, + placeholder_string="*", + add_pos=False, + emb_type="ocr", + **kwargs, + ): + super().__init__() + if hasattr(embedder, "tokenizer"): # using Stable Diffusion's CLIP encoder + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + token_dim = 768 + if hasattr(embedder, "vit"): + assert emb_type == "vit" + self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor) + self.get_recog_emb = None + else: # using LDM's BERT encoder + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + token_dim = 1280 + self.token_dim = token_dim + self.emb_type = emb_type + + self.add_pos = add_pos + if add_pos: + self.position_encoder = EncodeNet(position_channels, token_dim) + if emb_type == "ocr": + self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim)) + if emb_type == "conv": + self.glyph_encoder = EncodeNet(glyph_channels, token_dim) + + self.placeholder_token = get_token_for_string(placeholder_string) + + def encode_text(self, text_info): + if self.get_recog_emb is None and self.emb_type == "ocr": + self.get_recog_emb = partial(get_recog_emb, self.recog) + + gline_list = [] + pos_list = [] + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + for j in range(n_lines): # line + gline_list += [text_info["gly_line"][j][i : i + 1]] + if self.add_pos: + pos_list += [text_info["positions"][j][i : i + 1]] + + if len(gline_list) > 0: + if self.emb_type == "ocr": + recog_emb = self.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) + elif self.emb_type == "vit": + enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0))) + elif self.emb_type == "conv": + enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0))) + if self.add_pos: + enc_pos = self.position_encoder(torch.cat(gline_list, dim=0)) + enc_glyph = enc_glyph + enc_pos + + self.text_embs_all = [] + n_idx = 0 + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + text_embs = [] + for j in range(n_lines): # line + text_embs += [enc_glyph[n_idx : n_idx + 1]] + n_idx += 1 + self.text_embs_all += [text_embs] + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, device = tokenized_text.shape[0], tokenized_text.device + for i in range(b): + idx = tokenized_text[i] == self.placeholder_token.to(device) + if sum(idx) > 0: + if i >= len(self.text_embs_all): + print("truncation for log images...") + break + text_emb = torch.cat(self.text_embs_all[i], dim=0) + if sum(idx) != len(text_emb): + print("truncation for long caption...") + embedded_text[i][idx] = text_emb[: sum(idx)] + return embedded_text + + def embedding_parameters(self): + return self.parameters() diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index f46f56908aed..fe3c1a4486ab 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -7,6 +7,7 @@ import cv2 import numpy as np import torch +from easydict import EasyDict as edict from PIL import Image, ImageDraw, ImageFont from torch import nn @@ -14,19 +15,45 @@ from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from .embedding_manager import EmbeddingManager from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from .recognizer import TextRecognizer, create_predictor logger = logging.get_logger(__name__) # pylint: disable=invalid-name class TextEmbeddingModule(nn.Module): - def __init__(self, font_path): + def __init__(self, font_path, device): super().__init__() + self.device = device self.font = ImageFont.truetype(font_path, 60) self.ocr_model = ... self.linear = nn.Linear() - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3() + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) + self.embedding_manager_config = { + "valid": True, + "emb_type": "ocr", + "glyph_channels": 1, + "position_channels": 1, + "add_pos": False, + "placeholder_string": "*", + } + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, **self.embedding_manager_config) + # TODO: Understand the reason of param.requires_grad = True + for param in self.embedding_manager.embedding_parameters(): + param.requires_grad = True + rec_model_dir = "./ocr_weights/ppv3_rec.pth" + self.text_predictor = create_predictor(rec_model_dir).eval() + args = edict() + args.rec_image_shape = "3, 48, 320" + args.rec_batch_num = 6 + args.rec_char_dict_path = "./ocr_recog/ppocr_keys_v1.txt" + args.use_fp16 = self.use_fp16 + self.cn_recognizer = TextRecognizer(args, self.text_predictor) + for param in self.text_predictor.parameters(): + param.requires_grad = False + self.embedding_manager.recog = self.cn_recognizer @torch.no_grad() def forward(self, texts, prompt, device, num_images_per_prompt, do_classifier_free_guidance): From 0918cbd8f328d4cd9e992ce12a93a0aa1d9b3625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 2 Aug 2024 19:03:11 +0300 Subject: [PATCH 26/87] chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task --- .../anytext/auxiliary_latent_module.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 879ec8d90851..d245288c3c72 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -2,12 +2,15 @@ # +> fuse layer # position l_p -> position block -> +from typing import Optional + import cv2 import numpy as np import torch from PIL import Image, ImageDraw, ImageFont from torch import nn +from diffusers.models.autoencoders import AutoencoderKL from diffusers.utils import logging @@ -34,6 +37,20 @@ def zero_module(module: nn.Module) -> nn.Module: return module +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class AuxiliaryLatentModule(nn.Module): def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, model_channels=256, **kwargs): super().__init__() @@ -42,6 +59,7 @@ def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, self.font = ImageFont.truetype(font_path, 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") + self.scale_factor = 0.18215 self.glyph_block = nn.Sequential( conv_nd(dims, glyph_channels, 8, 3, padding=1), nn.SiLU(), @@ -80,6 +98,16 @@ def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, nn.SiLU(), ) + self.vae = AutoencoderKL.from_pretrained( + "runwayml/stable-diffusion-v1-5", + subfolder="vae", + torch_dtype=torch.float16 if self.use_fp16 else torch.float32, + variant="fp16" if self.use_fp16 else "fp32", + ) + self.vae.eval() + for param in self.vae.parameters(): + param.requires_grad = False + self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) @torch.no_grad() @@ -216,8 +244,7 @@ def forward( masked_img = torch.from_numpy(masked_img.copy()).float().cpu() if self.use_fp16: masked_img = masked_img.half() - encoder_posterior = self.encode_first_stage(masked_img[None, ...]) - masked_x = self.get_first_stage_encoding(encoder_posterior).detach() + masked_x = self.encode_first_stage(masked_img[None, ...]).detach() if self.use_fp16: masked_x = masked_x.half() masked_x = torch.cat([masked_x for _ in range(img_count)], dim=0) @@ -228,13 +255,12 @@ def forward( enc_pos = self.position_block(positions, emb, context) guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1)) - return guided_hint + hint = self.arr2tensor(np_hint, img_count) - def encode_first_stage(self, masked_img): - pass + return guided_hint, hint # , gly_pos_imgs - def get_first_stage_encoding(self, encoder_posterior): - pass + def encode_first_stage(self, masked_img): + return retrieve_latents(self.vae.encode(masked_img)) * self.scale_factor def arr2tensor(self, arr, bs): arr = np.transpose(arr, (2, 0, 1)) @@ -351,5 +377,6 @@ def to(self, device): self.device = device self.glyph_block = self.glyph_block.to(device) self.position_block = self.position_block.to(device) + self.vae = self.vae.to(device) self.fuse_block = self.fuse_block.to(device) return self From 37ae99fa65bc1dcce9ba0208ece130814b9eef94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 2 Aug 2024 19:06:16 +0300 Subject: [PATCH 27/87] `make style` --- .../convert_original_anytext_to_diffusers.py | 2 +- .../anytext/ocr_recog/RNN.py | 147 +++--- .../anytext/ocr_recog/RecCTCHead.py | 19 +- .../anytext/ocr_recog/RecModel.py | 27 +- .../anytext/ocr_recog/RecMv1_enhance.py | 120 ++--- .../anytext/ocr_recog/RecSVTR.py | 421 +++++++++--------- .../anytext/ocr_recog/common.py | 28 +- 7 files changed, 352 insertions(+), 412 deletions(-) diff --git a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py index 69a3d155ff33..4f5fd7aa01a8 100644 --- a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py +++ b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py @@ -1 +1 @@ -# In construction... \ No newline at end of file +# In construction... diff --git a/examples/research_projects/anytext/ocr_recog/RNN.py b/examples/research_projects/anytext/ocr_recog/RNN.py index cf16855b3711..aec796d987c0 100755 --- a/examples/research_projects/anytext/ocr_recog/RNN.py +++ b/examples/research_projects/anytext/ocr_recog/RNN.py @@ -1,13 +1,16 @@ -from torch import nn import torch +from torch import nn + from .RecSVTR import Block + class Swish(nn.Module): def __int__(self): super(Swish, self).__int__() - def forward(self,x): - return x*torch.sigmoid(x) + def forward(self, x): + return x * torch.sigmoid(x) + class Im2Im(nn.Module): def __init__(self, in_channels, **kwargs): @@ -17,6 +20,7 @@ def __init__(self, in_channels, **kwargs): def forward(self, x): return x + class Im2Seq(nn.Module): def __init__(self, in_channels, **kwargs): super().__init__() @@ -29,42 +33,40 @@ def forward(self, x): x = x.permute((0, 2, 1)) return x + class EncoderWithRNN(nn.Module): - def __init__(self, in_channels,**kwargs): + def __init__(self, in_channels, **kwargs): super(EncoderWithRNN, self).__init__() - hidden_size = kwargs.get('hidden_size', 256) + hidden_size = kwargs.get("hidden_size", 256) self.out_channels = hidden_size * 2 - self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True) + self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True) def forward(self, x): self.lstm.flatten_parameters() x, _ = self.lstm(x) return x + class SequenceEncoder(nn.Module): - def __init__(self, in_channels, encoder_type='rnn', **kwargs): + def __init__(self, in_channels, encoder_type="rnn", **kwargs): super(SequenceEncoder, self).__init__() self.encoder_reshape = Im2Seq(in_channels) self.out_channels = self.encoder_reshape.out_channels self.encoder_type = encoder_type - if encoder_type == 'reshape': + if encoder_type == "reshape": self.only_reshape = True else: - support_encoder_dict = { - 'reshape': Im2Seq, - 'rnn': EncoderWithRNN, - 'svtr': EncoderWithSVTR - } - assert encoder_type in support_encoder_dict, '{} must in {}'.format( - encoder_type, support_encoder_dict.keys()) - - self.encoder = support_encoder_dict[encoder_type]( - self.encoder_reshape.out_channels,**kwargs) + support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR} + assert encoder_type in support_encoder_dict, "{} must in {}".format( + encoder_type, support_encoder_dict.keys() + ) + + self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs) self.out_channels = self.encoder.out_channels self.only_reshape = False def forward(self, x): - if self.encoder_type != 'svtr': + if self.encoder_type != "svtr": x = self.encoder_reshape(x) if not self.only_reshape: x = self.encoder(x) @@ -74,16 +76,11 @@ def forward(self, x): x = self.encoder_reshape(x) return x + class ConvBNLayer(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=0, - bias_attr=False, - groups=1, - act=nn.GELU): + def __init__( + self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU + ): super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, @@ -93,7 +90,8 @@ def __init__(self, padding=padding, groups=groups, # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), - bias=bias_attr) + bias=bias_attr, + ) self.norm = nn.BatchNorm2d(out_channels) self.act = Swish() @@ -106,60 +104,60 @@ def forward(self, inputs): class EncoderWithSVTR(nn.Module): def __init__( - self, - in_channels, - dims=64, # XS - depth=2, - hidden_dims=120, - use_guide=False, - num_heads=8, - qkv_bias=True, - mlp_ratio=2.0, - drop_rate=0.1, - attn_drop_rate=0.1, - drop_path=0., - qk_scale=None): + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0.0, + qk_scale=None, + ): super(EncoderWithSVTR, self).__init__() self.depth = depth self.use_guide = use_guide - self.conv1 = ConvBNLayer( - in_channels, in_channels // 8, padding=1, act='swish') - self.conv2 = ConvBNLayer( - in_channels // 8, hidden_dims, kernel_size=1, act='swish') - - self.svtr_block = nn.ModuleList([ - Block( - dim=hidden_dims, - num_heads=num_heads, - mixer='Global', - HW=None, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer='swish', - attn_drop=attn_drop_rate, - drop_path=drop_path, - norm_layer='nn.LayerNorm', - epsilon=1e-05, - prenorm=False) for i in range(depth) - ]) + self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish") + self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish") + + self.svtr_block = nn.ModuleList( + [ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer="Global", + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer="swish", + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer="nn.LayerNorm", + epsilon=1e-05, + prenorm=False, + ) + for i in range(depth) + ] + ) self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) - self.conv3 = ConvBNLayer( - hidden_dims, in_channels, kernel_size=1, act='swish') + self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish") # last conv-nxn, the input is concat of input tensor and conv3 output tensor - self.conv4 = ConvBNLayer( - 2 * in_channels, in_channels // 8, padding=1, act='swish') + self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish") - self.conv1x1 = ConvBNLayer( - in_channels // 8, dims, kernel_size=1, act='swish') + self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish") self.out_channels = dims self.apply(self._init_weights) def _init_weights(self, m): # weight initialization if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): @@ -170,7 +168,7 @@ def _init_weights(self, m): if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.ConvTranspose2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): @@ -205,6 +203,7 @@ def forward(self, x): return z -if __name__=="__main__": + +if __name__ == "__main__": svtrRNN = EncoderWithSVTR(56) - print(svtrRNN) \ No newline at end of file + print(svtrRNN) diff --git a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py index 867ede9916b1..c066c6202b19 100755 --- a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py +++ b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py @@ -2,19 +2,16 @@ class CTCHead(nn.Module): - def __init__(self, - in_channels, - out_channels=6625, - fc_decay=0.0004, - mid_channels=None, - return_feats=False, - **kwargs): + def __init__( + self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs + ): super(CTCHead, self).__init__() if mid_channels is None: self.fc = nn.Linear( in_channels, out_channels, - bias=True,) + bias=True, + ) else: self.fc1 = nn.Linear( in_channels, @@ -39,9 +36,9 @@ def forward(self, x, labels=None): predicts = self.fc2(x) if self.return_feats: - result = dict() - result['ctc'] = predicts - result['ctc_neck'] = x + result = {} + result["ctc"] = predicts + result["ctc_neck"] = x else: result = predicts diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py index c2313bf02c95..50b0cec967d5 100755 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -1,31 +1,32 @@ from torch import nn -from .RNN import SequenceEncoder, Im2Seq, Im2Im -from .RecMv1_enhance import MobileNetV1Enhance from .RecCTCHead import CTCHead +from .RecMv1_enhance import MobileNetV1Enhance +from .RNN import Im2Im, Im2Seq, SequenceEncoder + -backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance} -neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} -head_dict = {'CTCHead':CTCHead} +backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance} +neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im} +head_dict = {"CTCHead": CTCHead} class RecModel(nn.Module): def __init__(self, config): super().__init__() - assert 'in_channels' in config, 'in_channels must in model config' - backbone_type = config.backbone.pop('type') - assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}' + assert "in_channels" in config, "in_channels must in model config" + backbone_type = config.backbone.pop("type") + assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}" self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) - neck_type = config.neck.pop('type') - assert neck_type in neck_dict, f'neck.type must in {neck_dict}' + neck_type = config.neck.pop("type") + assert neck_type in neck_dict, f"neck.type must in {neck_dict}" self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) - head_type = config.head.pop('type') - assert head_type in head_dict, f'head.type must in {head_dict}' + head_type = config.head.pop("type") + assert head_type in head_dict, f"head.type must in {head_dict}" self.head = head_dict[head_type](self.neck.out_channels, **config.head) - self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' + self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}" def load_3rd_state_dict(self, _3rd_name, _state): self.backbone.load_3rd_state_dict(_3rd_name, _state) diff --git a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py index d5c848533dd3..df41519b2713 100644 --- a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py +++ b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py @@ -1,20 +1,14 @@ -import os, sys import torch import torch.nn as nn import torch.nn.functional as F + from .common import Activation class ConvBNLayer(nn.Module): - def __init__(self, - num_channels, - filter_size, - num_filters, - stride, - padding, - channels=None, - num_groups=1, - act='hard_swish'): + def __init__( + self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish" + ): super(ConvBNLayer, self).__init__() self.act = act self._conv = nn.Conv2d( @@ -24,7 +18,8 @@ def __init__(self, stride=stride, padding=padding, groups=num_groups, - bias=False) + bias=False, + ) self._batch_norm = nn.BatchNorm2d( num_filters, @@ -41,16 +36,9 @@ def forward(self, inputs): class DepthwiseSeparable(nn.Module): - def __init__(self, - num_channels, - num_filters1, - num_filters2, - num_groups, - stride, - scale, - dw_size=3, - padding=1, - use_se=False): + def __init__( + self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False + ): super(DepthwiseSeparable, self).__init__() self.use_se = use_se self._depthwise_conv = ConvBNLayer( @@ -59,7 +47,8 @@ def __init__(self, filter_size=dw_size, stride=stride, padding=padding, - num_groups=int(num_groups * scale)) + num_groups=int(num_groups * scale), + ) if use_se: self._se = SEModule(int(num_filters1 * scale)) self._pointwise_conv = ConvBNLayer( @@ -67,7 +56,8 @@ def __init__(self, filter_size=1, num_filters=int(num_filters2 * scale), stride=1, - padding=0) + padding=0, + ) def forward(self, inputs): y = self._depthwise_conv(inputs) @@ -78,49 +68,28 @@ def forward(self, inputs): class MobileNetV1Enhance(nn.Module): - def __init__(self, - in_channels=3, - scale=0.5, - last_conv_stride=1, - last_pool_type='max', - **kwargs): + def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs): super().__init__() self.scale = scale self.block_list = [] self.conv1 = ConvBNLayer( - num_channels=in_channels, - filter_size=3, - channels=3, - num_filters=int(32 * scale), - stride=2, - padding=1) + num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1 + ) conv2_1 = DepthwiseSeparable( - num_channels=int(32 * scale), - num_filters1=32, - num_filters2=64, - num_groups=32, - stride=1, - scale=scale) + num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale + ) self.block_list.append(conv2_1) conv2_2 = DepthwiseSeparable( - num_channels=int(64 * scale), - num_filters1=64, - num_filters2=128, - num_groups=64, - stride=1, - scale=scale) + num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale + ) self.block_list.append(conv2_2) conv3_1 = DepthwiseSeparable( - num_channels=int(128 * scale), - num_filters1=128, - num_filters2=128, - num_groups=128, - stride=1, - scale=scale) + num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale + ) self.block_list.append(conv3_1) conv3_2 = DepthwiseSeparable( @@ -129,16 +98,13 @@ def __init__(self, num_filters2=256, num_groups=128, stride=(2, 1), - scale=scale) + scale=scale, + ) self.block_list.append(conv3_2) conv4_1 = DepthwiseSeparable( - num_channels=int(256 * scale), - num_filters1=256, - num_filters2=256, - num_groups=256, - stride=1, - scale=scale) + num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale + ) self.block_list.append(conv4_1) conv4_2 = DepthwiseSeparable( @@ -147,7 +113,8 @@ def __init__(self, num_filters2=512, num_groups=256, stride=(2, 1), - scale=scale) + scale=scale, + ) self.block_list.append(conv4_2) for _ in range(5): @@ -160,7 +127,8 @@ def __init__(self, dw_size=5, padding=2, scale=scale, - use_se=False) + use_se=False, + ) self.block_list.append(conv5) conv5_6 = DepthwiseSeparable( @@ -172,7 +140,8 @@ def __init__(self, dw_size=5, padding=2, scale=scale, - use_se=True) + use_se=True, + ) self.block_list.append(conv5_6) conv6 = DepthwiseSeparable( @@ -184,11 +153,12 @@ def __init__(self, dw_size=5, padding=2, use_se=True, - scale=scale) + scale=scale, + ) self.block_list.append(conv6) self.block_list = nn.Sequential(*self.block_list) - if last_pool_type == 'avg': + if last_pool_type == "avg": self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) else: self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) @@ -200,27 +170,21 @@ def forward(self, inputs): y = self.pool(y) return y + def hardsigmoid(x): - return F.relu6(x + 3., inplace=True) / 6. + return F.relu6(x + 3.0, inplace=True) / 6.0 + class SEModule(nn.Module): def __init__(self, channel, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv1 = nn.Conv2d( - in_channels=channel, - out_channels=channel // reduction, - kernel_size=1, - stride=1, - padding=0, - bias=True) + in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True + ) self.conv2 = nn.Conv2d( - in_channels=channel // reduction, - out_channels=channel, - kernel_size=1, - stride=1, - padding=0, - bias=True) + in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True + ) def forward(self, inputs): outputs = self.avg_pool(inputs) diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py index 484b3df99125..590a96995b26 100644 --- a/examples/research_projects/anytext/ocr_recog/RecSVTR.py +++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py @@ -1,19 +1,19 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np -from torch.nn.init import trunc_normal_, zeros_, ones_ from torch.nn import functional +from torch.nn.init import ones_, trunc_normal_, zeros_ -def drop_path(x, drop_prob=0., training=False): +def drop_path(x, drop_prob=0.0, training=False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... """ - if drop_prob == 0. or not training: + if drop_prob == 0.0 or not training: return x keep_prob = torch.tensor(1 - drop_prob) - shape = (x.size()[0], ) + (1, ) * (x.ndim - 1) + shape = (x.size()[0],) + (1,) * (x.ndim - 1) random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype) random_tensor = torch.floor(random_tensor) # binarize output = x.divide(keep_prob) * random_tensor @@ -24,20 +24,14 @@ class Swish(nn.Module): def __int__(self): super(Swish, self).__int__() - def forward(self,x): - return x*torch.sigmoid(x) + def forward(self, x): + return x * torch.sigmoid(x) class ConvBNLayer(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=0, - bias_attr=False, - groups=1, - act=nn.GELU): + def __init__( + self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU + ): super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, @@ -47,7 +41,8 @@ def __init__(self, padding=padding, groups=groups, # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), - bias=bias_attr) + bias=bias_attr, + ) self.norm = nn.BatchNorm2d(out_channels) self.act = act() @@ -59,8 +54,7 @@ def forward(self, inputs): class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() @@ -79,12 +73,7 @@ def forward(self, input): class Mlp(nn.Module): - def __init__(self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -107,11 +96,12 @@ def forward(self, x): class ConvMixer(nn.Module): def __init__( - self, - dim, - num_heads=8, - HW=(8, 25), - local_k=(3, 3), ): + self, + dim, + num_heads=8, + HW=(8, 25), + local_k=(3, 3), + ): super().__init__() self.HW = HW self.dim = dim @@ -119,7 +109,8 @@ def __init__( dim, dim, local_k, - 1, (local_k[0] // 2, local_k[1] // 2), + 1, + (local_k[0] // 2, local_k[1] // 2), groups=num_heads, # weight_attr=ParamAttr(initializer=KaimingNormal()) ) @@ -134,16 +125,18 @@ def forward(self, x): class Attention(nn.Module): - def __init__(self, - dim, - num_heads=8, - mixer='Global', - HW=(8, 25), - local_k=(7, 11), - qkv_bias=False, - qk_scale=None, - attn_drop=0., - proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + mixer="Global", + HW=(8, 25), + local_k=(7, 11), + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -159,18 +152,17 @@ def __init__(self, W = HW[1] self.N = H * W self.C = dim - if mixer == 'Local' and HW is not None: + if mixer == "Local" and HW is not None: hk = local_k[0] wk = local_k[1] mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) for h in range(0, H): for w in range(0, W): - mask[h * W + w, h:h + hk, w:w + wk] = 0. - mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // - 2].flatten(1) - mask_inf = torch.full([H * W, H * W],fill_value=float('-inf')) + mask[h * W + w, h : h + hk, w : w + wk] = 0.0 + mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1) + mask_inf = torch.full([H * W, H * W], fill_value=float("-inf")) mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) - self.mask = mask[None,None,:] + self.mask = mask[None, None, :] # self.mask = mask.unsqueeze([0, 1]) self.mixer = mixer @@ -180,11 +172,11 @@ def forward(self, x): C = self.C else: _, N, C = x.shape - qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4)) + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4)) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = (q.matmul(k.permute((0, 1, 3, 2)))) - if self.mixer == 'Local': + attn = q.matmul(k.permute((0, 1, 3, 2))) + if self.mixer == "Local": attn += self.mask attn = functional.softmax(attn, dim=-1) attn = self.attn_drop(attn) @@ -196,29 +188,30 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, - dim, - num_heads, - mixer='Global', - local_mixer=(7, 11), - HW=(8, 25), - mlp_ratio=4., - qkv_bias=False, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer='nn.LayerNorm', - epsilon=1e-6, - prenorm=True): + def __init__( + self, + dim, + num_heads, + mixer="Global", + local_mixer=(7, 11), + HW=(8, 25), + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer="nn.LayerNorm", + epsilon=1e-6, + prenorm=True, + ): super().__init__() if isinstance(norm_layer, str): self.norm1 = eval(norm_layer)(dim, eps=epsilon) else: self.norm1 = norm_layer(dim) - if mixer == 'Global' or mixer == 'Local': - + if mixer == "Global" or mixer == "Local": self.mixer = Attention( dim, num_heads=num_heads, @@ -228,24 +221,21 @@ def __init__(self, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop) - elif mixer == 'Conv': - self.mixer = ConvMixer( - dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + proj_drop=drop, + ) + elif mixer == "Conv": + self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer) else: raise TypeError("The mixer must be one of [Global, Local, Conv]") - self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity() if isinstance(norm_layer, str): self.norm2 = eval(norm_layer)(dim, eps=epsilon) else: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp_ratio = mlp_ratio - self.mlp = Mlp(in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.prenorm = prenorm def forward(self, x): @@ -259,17 +249,11 @@ def forward(self, x): class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ + """Image to Patch Embedding""" - def __init__(self, - img_size=(32, 100), - in_channels=3, - embed_dim=768, - sub_num=2): + def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2): super().__init__() - num_patches = (img_size[1] // (2 ** sub_num)) * \ - (img_size[0] // (2 ** sub_num)) + num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num)) self.img_size = img_size self.num_patches = num_patches self.embed_dim = embed_dim @@ -283,7 +267,8 @@ def __init__(self, stride=2, padding=1, act=nn.GELU, - bias_attr=False), + bias_attr=False, + ), ConvBNLayer( in_channels=embed_dim // 2, out_channels=embed_dim, @@ -291,7 +276,9 @@ def __init__(self, stride=2, padding=1, act=nn.GELU, - bias_attr=False)) + bias_attr=False, + ), + ) if sub_num == 3: self.proj = nn.Sequential( ConvBNLayer( @@ -301,7 +288,8 @@ def __init__(self, stride=2, padding=1, act=nn.GELU, - bias_attr=False), + bias_attr=False, + ), ConvBNLayer( in_channels=embed_dim // 4, out_channels=embed_dim // 2, @@ -309,7 +297,8 @@ def __init__(self, stride=2, padding=1, act=nn.GELU, - bias_attr=False), + bias_attr=False, + ), ConvBNLayer( in_channels=embed_dim // 2, out_channels=embed_dim, @@ -317,31 +306,26 @@ def __init__(self, stride=2, padding=1, act=nn.GELU, - bias_attr=False)) + bias_attr=False, + ), + ) def forward(self, x): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).permute(0, 2, 1) return x class SubSample(nn.Module): - def __init__(self, - in_channels, - out_channels, - types='Pool', - stride=(2, 1), - sub_norm='nn.LayerNorm', - act=None): + def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None): super().__init__() self.types = types - if types == 'Pool': - self.avgpool = nn.AvgPool2d( - kernel_size=(3, 5), stride=stride, padding=(1, 2)) - self.maxpool = nn.MaxPool2d( - kernel_size=(3, 5), stride=stride, padding=(1, 2)) + if types == "Pool": + self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) self.proj = nn.Linear(in_channels, out_channels) else: self.conv = nn.Conv2d( @@ -359,8 +343,7 @@ def __init__(self, self.act = None def forward(self, x): - - if self.types == 'Pool': + if self.types == "Pool": x1 = self.avgpool(x) x2 = self.maxpool(x) x = (x1 + x2) * 0.5 @@ -377,46 +360,44 @@ def forward(self, x): class SVTRNet(nn.Module): def __init__( - self, - img_size=[48, 100], - in_channels=3, - embed_dim=[64, 128, 256], - depth=[3, 6, 3], - num_heads=[2, 4, 8], - mixer=['Local'] * 6 + ['Global'] * - 6, # Local atten, Global atten, Conv - local_mixer=[[7, 11], [7, 11], [7, 11]], - patch_merging='Conv', # Conv, Pool, None - mlp_ratio=4, - qkv_bias=True, - qk_scale=None, - drop_rate=0., - last_drop=0.1, - attn_drop_rate=0., - drop_path_rate=0.1, - norm_layer='nn.LayerNorm', - sub_norm='nn.LayerNorm', - epsilon=1e-6, - out_channels=192, - out_char_num=25, - block_unit='Block', - act='nn.GELU', - last_stage=True, - sub_num=2, - prenorm=True, - use_lenhead=False, - **kwargs): + self, + img_size=[48, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging="Conv", # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + last_drop=0.1, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer="nn.LayerNorm", + sub_norm="nn.LayerNorm", + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit="Block", + act="nn.GELU", + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs, + ): super().__init__() self.img_size = img_size self.embed_dim = embed_dim self.out_channels = out_channels self.prenorm = prenorm - patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging + patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging self.patch_embed = PatchEmbed( - img_size=img_size, - in_channels=in_channels, - embed_dim=embed_dim[0], - sub_num=sub_num) + img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num + ) num_patches = self.patch_embed.num_patches self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) @@ -431,81 +412,85 @@ def __init__( dpr = np.linspace(0, drop_path_rate, sum(depth)) self.blocks1 = nn.ModuleList( [ - Block_unit( - dim=embed_dim[0], - num_heads=num_heads[0], - mixer=mixer[0:depth[0]][i], - HW=self.HW, - local_mixer=local_mixer[0], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[0:depth[0]][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm) for i in range(depth[0]) - ] + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0 : depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[0 : depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[0]) + ] ) if patch_merging is not None: self.sub_sample1 = SubSample( - embed_dim[0], - embed_dim[1], - sub_norm=sub_norm, - stride=[2, 1], - types=patch_merging) + embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging + ) HW = [self.HW[0] // 2, self.HW[1]] else: HW = self.HW self.patch_merging = patch_merging - self.blocks2 = nn.ModuleList([ - Block_unit( - dim=embed_dim[1], - num_heads=num_heads[1], - mixer=mixer[depth[0]:depth[0] + depth[1]][i], - HW=HW, - local_mixer=local_mixer[1], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[depth[0]:depth[0] + depth[1]][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm) for i in range(depth[1]) - ]) + self.blocks2 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + mixer=mixer[depth[0] : depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] : depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[1]) + ] + ) if patch_merging is not None: self.sub_sample2 = SubSample( - embed_dim[1], - embed_dim[2], - sub_norm=sub_norm, - stride=[2, 1], - types=patch_merging) + embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging + ) HW = [self.HW[0] // 4, self.HW[1]] else: HW = self.HW - self.blocks3 = nn.ModuleList([ - Block_unit( - dim=embed_dim[2], - num_heads=num_heads[2], - mixer=mixer[depth[0] + depth[1]:][i], - HW=HW, - local_mixer=local_mixer[2], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[depth[0] + depth[1]:][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm) for i in range(depth[2]) - ]) + self.blocks3 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1] :][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1] :][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[2]) + ] + ) self.last_stage = last_stage if last_stage: self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num)) @@ -515,7 +500,8 @@ def __init__( kernel_size=1, stride=1, padding=0, - bias=False) + bias=False, + ) self.hardswish = nn.Hardswish() self.dropout = nn.Dropout(p=last_drop) if not prenorm: @@ -524,15 +510,14 @@ def __init__( if use_lenhead: self.len_conv = nn.Linear(embed_dim[2], self.out_channels) self.hardswish_len = nn.Hardswish() - self.dropout_len = nn.Dropout( - p=last_drop) + self.dropout_len = nn.Dropout(p=last_drop) - trunc_normal_(self.pos_embed,std=.02) + trunc_normal_(self.pos_embed, std=0.02) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight,std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: zeros_(m.bias) elif isinstance(m, nn.LayerNorm): @@ -546,15 +531,11 @@ def forward_features(self, x): for blk in self.blocks1: x = blk(x) if self.patch_merging is not None: - x = self.sub_sample1( - x.permute([0, 2, 1]).reshape( - [-1, self.embed_dim[0], self.HW[0], self.HW[1]])) + x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]])) for blk in self.blocks2: x = blk(x) if self.patch_merging is not None: - x = self.sub_sample2( - x.permute([0, 2, 1]).reshape( - [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) for blk in self.blocks3: x = blk(x) if not self.prenorm: @@ -571,9 +552,7 @@ def forward(self, x): h = self.HW[0] // 4 else: h = self.HW[0] - x = self.avg_pool( - x.permute([0, 2, 1]).reshape( - [-1, self.embed_dim[2], h, self.HW[1]])) + x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]])) x = self.last_conv(x) x = self.hardswish(x) x = self.dropout(x) @@ -582,10 +561,10 @@ def forward(self, x): return x -if __name__=="__main__": - a = torch.rand(1,3,48,100) +if __name__ == "__main__": + a = torch.rand(1, 3, 48, 100) svtr = SVTRNet() out = svtr(a) print(svtr) - print(out.size()) \ No newline at end of file + print(out.size()) diff --git a/examples/research_projects/anytext/ocr_recog/common.py b/examples/research_projects/anytext/ocr_recog/common.py index a328bb034a37..207a95b17d0e 100644 --- a/examples/research_projects/anytext/ocr_recog/common.py +++ b/examples/research_projects/anytext/ocr_recog/common.py @@ -1,5 +1,3 @@ - - import torch import torch.nn as nn import torch.nn.functional as F @@ -11,7 +9,8 @@ def __init__(self, inplace=True): self.inplace = inplace def forward(self, x): - return x * F.relu6(x + 3., inplace=self.inplace) / 6. + return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + # out = max(0, min(1, slop*x+offset)) # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) @@ -23,7 +22,8 @@ def __init__(self, inplace=True): def forward(self, x): # torch: F.relu6(x + 3., inplace=self.inplace) / 6. # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. - return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0 + class GELU(nn.Module): def __init__(self, inplace=True): @@ -44,31 +44,31 @@ def forward(self, x): x.mul_(torch.sigmoid(x)) return x else: - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) class Activation(nn.Module): def __init__(self, act_type, inplace=True): super(Activation, self).__init__() act_type = act_type.lower() - if act_type == 'relu': + if act_type == "relu": self.act = nn.ReLU(inplace=inplace) - elif act_type == 'relu6': + elif act_type == "relu6": self.act = nn.ReLU6(inplace=inplace) - elif act_type == 'sigmoid': + elif act_type == "sigmoid": raise NotImplementedError - elif act_type == 'hard_sigmoid': + elif act_type == "hard_sigmoid": self.act = Hsigmoid(inplace) - elif act_type == 'hard_swish': + elif act_type == "hard_swish": self.act = Hswish(inplace=inplace) - elif act_type == 'leakyrelu': + elif act_type == "leakyrelu": self.act = nn.LeakyReLU(inplace=inplace) - elif act_type == 'gelu': + elif act_type == "gelu": self.act = GELU(inplace=inplace) - elif act_type == 'swish': + elif act_type == "swish": self.act = Swish(inplace=inplace) else: raise NotImplementedError def forward(self, inputs): - return self.act(inputs) \ No newline at end of file + return self.act(inputs) From b475a3b6ea65a3e29bcd9fa726e2d145efa2e168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 4 Aug 2024 13:14:05 +0300 Subject: [PATCH 28/87] refactor: Update `AnyTextPipeline`'s docstring --- .../anytext/pipeline_anytext.py | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 66bc491b37f4..03bf3b8a41fe 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -62,45 +62,36 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from pipeline_anytext import AnyTextPipeline + >>> from text_controlnet import TextControlNetModel + >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image - >>> import numpy as np >>> import torch - >>> import cv2 - >>> from PIL import Image - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ... ) - >>> image = np.array(image) - - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 - ... ) + >>> text_controlnet = TextControlNetModel.from_pretrained("a/TextControlNet", torch_dtype=torch.float16) + >>> pipe = AnyTextPipeline.from_pretrained( + ... "a/AnyText", controlnet=text_controlnet, torch_dtype=torch.float16, + ... variant="fp16" + ... ).to("cuda") - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization + >>> #pipe.enable_xformers_memory_efficient_attention() - >>> pipe.enable_model_cpu_offload() + >>> # uncomment following line if you want to offload the model to CPU for memory optimization + >>> # also remove the `.to("cuda")` part + >>> #pipe.enable_model_cpu_offload() >>> # generate image - >>> generator = torch.manual_seed(0) + >>> generator = torch.Generator("cpu").manual_seed(66273235) + >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' + >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") >>> image = pipe( - ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... prompt, num_inference_steps=20, generator=generator, mode="generate", + ... draw_pos=draw_pos ... ).images[0] + >>> image ``` """ From ea957f08e0193f454e0f3180f6e9df81911ecfda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 4 Aug 2024 18:44:54 +0300 Subject: [PATCH 29/87] Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once --- .../anytext/auxiliary_latent_module.py | 54 ++++++++----------- .../anytext/pipeline_anytext.py | 22 ++++---- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index d245288c3c72..8196804f8d67 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -120,7 +120,6 @@ def forward( prompt, draw_pos, ori_image, - img_count, max_chars=77, revise_pos=False, sort_priority=False, @@ -196,10 +195,11 @@ def forward( poly_list += [None] np_hint = np.sum(pre_pos, axis=0).clip(0, 1) # prepare info dict - glyphs_list = [] - positions = [] - n_lines = [len(texts)] * img_count - gly_pos_imgs = [] + info = {} + info['glyphs'] = [] + info['gly_line'] = [] + info['positions'] = [] + info['n_lines'] = [len(texts)]*len(prompt) for i in range(len(texts)): text = texts[i] if len(text) > max_chars: @@ -208,56 +208,46 @@ def forward( text = text[:max_chars] gly_scale = 2 if pre_pos[i].mean() != 0: - glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1) + gly_line = self.draw_glyph(self.font, text) + glyphs = self.draw_glyph2(self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False) if revise_pos: resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) + new_pos = cv2.morphologyEx((resize_gly*255).astype(np.uint8), cv2.MORPH_CLOSE, kernel=np.ones((resize_gly.shape[0]//10, resize_gly.shape[1]//10), dtype=np.uint8), iterations=1) new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + str_warning = f'Fail to revise position {i} to bounding rect, remain position unchanged...' logger.warning(str_warning) else: rect = cv2.minAreaRect(contours[0]) poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 - gly_pos_img = cv2.drawContours(glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1) - gly_pos_imgs += [gly_pos_img] # for show + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255. else: - glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_pos_imgs += [np.zeros((h * gly_scale, w * gly_scale, 1))] # for show + glyphs = np.zeros((h*gly_scale, w*gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) pos = pre_pos[i] - glyphs_list += [self.arr2tensor(glyphs, img_count)] - positions += [self.arr2tensor(pos, img_count)] - + info['glyphs'] += [self.arr2tensor(glyphs, len(prompt))] + info['gly_line'] += [self.arr2tensor(gly_line, len(prompt))] + info['positions'] += [self.arr2tensor(pos, len(prompt))] # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0)*(1-np_hint) masked_img = np.transpose(masked_img, (2, 0, 1)) - masked_img = torch.from_numpy(masked_img.copy()).float().cpu() + masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) if self.use_fp16: masked_img = masked_img.half() masked_x = self.encode_first_stage(masked_img[None, ...]).detach() if self.use_fp16: masked_x = masked_x.half() - masked_x = torch.cat([masked_x for _ in range(img_count)], dim=0) + info['masked_x'] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) + hint = self.arr2tensor(np_hint, len(prompt)) - glyphs = torch.cat(glyphs_list, dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(positions, dim=1).sum(dim=1, keepdim=True) + glyphs = torch.cat(info['glyphs'], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(info['positions'], dim=1).sum(dim=1, keepdim=True) enc_glyph = self.glyph_block(glyphs, emb, context) enc_pos = self.position_block(positions, emb, context) guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1)) - hint = self.arr2tensor(np_hint, img_count) - - return guided_hint, hint # , gly_pos_imgs + return guided_hint, hint, info def encode_first_stage(self, masked_img): return retrieve_latents(self.vae.encode(masked_img)) * self.scale_factor diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 03bf3b8a41fe..fd903ff666d4 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1151,17 +1151,6 @@ def __call__( ) prompt, texts = self.modify_prompt(prompt) - prompt_embeds, negative_prompt_embeds = self.text_embedding_module( - prompt, - device, - num_images_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes @@ -1199,7 +1188,7 @@ def __call__( # guess_mode=guess_mode, # ) # height, width = image.shape[-2:] - guided_hint = self.auxiliary_latent_module( + guided_hint, hint, text_info = self.auxiliary_latent_module( emb=timestep_cond, context=prompt_embeds, mode=mode, @@ -1237,6 +1226,15 @@ def __call__( else: assert False + prompt_embeds, negative_prompt_embeds = self.text_embedding_module( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + hint, + negative_prompt, + text_info, + ) # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas From cc0c6e590150381e70c4596b8da3f31bbe745d6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 4 Aug 2024 18:46:06 +0300 Subject: [PATCH 30/87] simplify --- .../anytext/text_embedding_module.py | 350 ++---------------- 1 file changed, 30 insertions(+), 320 deletions(-) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index fe3c1a4486ab..4cf4149eec83 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -14,6 +14,7 @@ from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from .embedding_manager import EmbeddingManager from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 @@ -24,12 +25,11 @@ class TextEmbeddingModule(nn.Module): - def __init__(self, font_path, device): + def __init__(self, font_path, device, use_fp16): super().__init__() self.device = device + # TODO: Learn if the recommended font file is free to use self.font = ImageFont.truetype(font_path, 60) - self.ocr_model = ... - self.linear = nn.Linear() self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) self.embedding_manager_config = { "valid": True, @@ -45,335 +45,45 @@ def __init__(self, font_path, device): param.requires_grad = True rec_model_dir = "./ocr_weights/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() - args = edict() - args.rec_image_shape = "3, 48, 320" - args.rec_batch_num = 6 - args.rec_char_dict_path = "./ocr_recog/ppocr_keys_v1.txt" - args.use_fp16 = self.use_fp16 + args = {} + args['rec_image_shape'] = "3, 48, 320" + args['rec_batch_num'] = 6 + args['rec_char_dict_path'] = "./ocr_recog/ppocr_keys_v1.txt" + args['use_fp16'] = use_fp16 self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): param.requires_grad = False self.embedding_manager.recog = self.cn_recognizer @torch.no_grad() - def forward(self, texts, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - glyph_lines = self.create_glyph_lines(texts) - ocr_output = self.ocr(glyph_lines) - _ = self.linear(ocr_output) - # Token Replacement + def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info): + prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info}) + negative_prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info}) - # FrozenCLIPEmbedderT3 - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - lora_scale=None, - clip_skip=None, - ) return prompt_embeds, negative_prompt_embeds - def ocr(self, glyph_lines): - pass - - def create_glyph_lines( - self, - texts, - mode="text-generation", - img_count=1, - max_chars=77, - draw_pos=None, - ori_image=None, - sort_priority=False, - h=512, - w=512, - ): - if mode in ["text-generation", "gen"]: - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode in ["text-editing", "edit"]: - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - h, w = edit_image.shape[:2] # change h, w by input ref_img - # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) - if draw_pos is None: - pos_imgs = np.zeros((w, h, 1)) - if isinstance(draw_pos, str): - draw_pos = cv2.imread(draw_pos)[..., ::-1] - pos_imgs = 255 - draw_pos - elif isinstance(draw_pos, torch.Tensor): - pos_imgs = draw_pos.cpu().numpy() - if mode in ["text-editing", "edit"]: - pos_imgs = cv2.resize(pos_imgs, (w, h)) - pos_imgs = pos_imgs[..., 0:1] - pos_imgs = cv2.convertScaleAbs(pos_imgs) - _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) - # separate pos_imgs - pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) - if len(pos_imgs) == 0: - pos_imgs = [np.zeros((h, w, 1))] - # get pre_pos that needed for anytext - pre_pos = [] - for input_pos in pos_imgs: - if input_pos.mean() != 0: - input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos - poly, pos_img = self.find_polygon(input_pos) - pre_pos += [pos_img / 255.0] + def get_learned_conditioning(self, c): + if hasattr(self.frozen_CLIP_embedder_t3, 'encode') and callable(self.frozen_CLIP_embedder_t3.encode): + if self.embedding_manager is not None and c['text_info'] is not None: + self.embedding_manager.encode_text(c['text_info']) + if isinstance(c, dict): + cond_txt = c['c_crossattn'][0] else: - pre_pos += [np.zeros((h, w, 1))] - # prepare info dict - gly_lines = [] - for i in range(len(texts)): - text = texts[i] - if len(text) > max_chars: - text = text[:max_chars] - if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.font, text) + cond_txt = c + if self.embedding_manager is not None: + cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt, embedding_manager=self.embedding_manager) else: - gly_line = np.zeros((80, 512, 1)) - gly_lines += [self.arr2tensor(gly_line, img_count)] - - return gly_lines - - def check_channels(self, image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - def resize_image(self, img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - def draw_glyph(self, font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) + cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt) + if isinstance(c, dict): + c['c_crossattn'][0] = cond_txt else: - scale_lora_layers(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + c = cond_txt + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.frozen_CLIP_embedder_t3.tokenizer) - - text_inputs = self.frozen_CLIP_embedder_t3.tokenizer( - prompt, - padding="max_length", - max_length=self.frozen_CLIP_embedder_t3.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.frozen_CLIP_embedder_t3.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + c = self.frozen_CLIP_embedder_t3(c) - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.frozen_CLIP_embedder_t3.tokenizer.batch_decode( - untruncated_ids[:, self.frozen_CLIP_embedder_t3.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.frozen_CLIP_embedder_t3.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if ( - hasattr(self.frozen_CLIP_embedder_t3.text_encoder.config, "use_attention_mask") - and self.frozen_CLIP_embedder_t3.text_encoder.config.use_attention_mask - ): - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder.text_model.final_layer_norm(prompt_embeds) + return c - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.frozen_CLIP_embedder_t3.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.frozen_CLIP_embedder_t3.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if ( - hasattr(self.frozen_CLIP_embedder_t3.text_encoder.config, "use_attention_mask") - and self.frozen_CLIP_embedder_t3.text_encoder.config.use_attention_mask - ): - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.frozen_CLIP_embedder_t3.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if self.frozen_CLIP_embedder_t3.text_encoder is not None: - if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.frozen_CLIP_embedder_t3.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds + def get_unconditional_conditioning(self, N): + return self.get_learned_conditioning({"c_crossattn": [[""] * N], "text_info": None}) From 52fb0b4de60c0936faa5099447c9874a83f64b4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 4 Aug 2024 18:47:15 +0300 Subject: [PATCH 31/87] `make style` --- .../anytext/auxiliary_latent_module.py | 39 +++++++++++-------- .../anytext/text_embedding_module.py | 36 ++++++++--------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 8196804f8d67..6571d5bb5c9b 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -196,10 +196,10 @@ def forward( np_hint = np.sum(pre_pos, axis=0).clip(0, 1) # prepare info dict info = {} - info['glyphs'] = [] - info['gly_line'] = [] - info['positions'] = [] - info['n_lines'] = [len(texts)]*len(prompt) + info["glyphs"] = [] + info["gly_line"] = [] + info["positions"] = [] + info["n_lines"] = [len(texts)] * len(prompt) for i in range(len(texts)): text = texts[i] if len(text) > max_chars: @@ -209,28 +209,35 @@ def forward( gly_scale = 2 if pre_pos[i].mean() != 0: gly_line = self.draw_glyph(self.font, text) - glyphs = self.draw_glyph2(self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False) + glyphs = self.draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) if revise_pos: resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx((resize_gly*255).astype(np.uint8), cv2.MORPH_CLOSE, kernel=np.ones((resize_gly.shape[0]//10, resize_gly.shape[1]//10), dtype=np.uint8), iterations=1) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) if len(contours) != 1: - str_warning = f'Fail to revise position {i} to bounding rect, remain position unchanged...' + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." logger.warning(str_warning) else: rect = cv2.minAreaRect(contours[0]) poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255. + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 else: - glyphs = np.zeros((h*gly_scale, w*gly_scale, 1)) + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) gly_line = np.zeros((80, 512, 1)) pos = pre_pos[i] - info['glyphs'] += [self.arr2tensor(glyphs, len(prompt))] - info['gly_line'] += [self.arr2tensor(gly_line, len(prompt))] - info['positions'] += [self.arr2tensor(pos, len(prompt))] + info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))] + info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))] + info["positions"] += [self.arr2tensor(pos, len(prompt))] # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0)*(1-np_hint) + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) masked_img = np.transpose(masked_img, (2, 0, 1)) masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) if self.use_fp16: @@ -238,11 +245,11 @@ def forward( masked_x = self.encode_first_stage(masked_img[None, ...]).detach() if self.use_fp16: masked_x = masked_x.half() - info['masked_x'] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) + info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) hint = self.arr2tensor(np_hint, len(prompt)) - glyphs = torch.cat(info['glyphs'], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(info['positions'], dim=1).sum(dim=1, keepdim=True) + glyphs = torch.cat(info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(info["positions"], dim=1).sum(dim=1, keepdim=True) enc_glyph = self.glyph_block(glyphs, emb, context) enc_pos = self.position_block(positions, emb, context) guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1)) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 4cf4149eec83..e7e5ef98db13 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -2,19 +2,13 @@ # +> Token Replacement -> FrozenCLIPEmbedderT3 # text -> tokenizer -> -from typing import List, Optional -import cv2 -import numpy as np import torch -from easydict import EasyDict as edict -from PIL import Image, ImageDraw, ImageFont +from PIL import ImageFont from torch import nn -from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.utils import logging from .embedding_manager import EmbeddingManager from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 @@ -46,10 +40,10 @@ def __init__(self, font_path, device, use_fp16): rec_model_dir = "./ocr_weights/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} - args['rec_image_shape'] = "3, 48, 320" - args['rec_batch_num'] = 6 - args['rec_char_dict_path'] = "./ocr_recog/ppocr_keys_v1.txt" - args['use_fp16'] = use_fp16 + args["rec_image_shape"] = "3, 48, 320" + args["rec_batch_num"] = 6 + args["rec_char_dict_path"] = "./ocr_recog/ppocr_keys_v1.txt" + args["use_fp16"] = use_fp16 self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): param.requires_grad = False @@ -57,17 +51,21 @@ def __init__(self, font_path, device, use_fp16): @torch.no_grad() def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info): - prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info}) - negative_prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info}) + prompt_embeds = self.get_learned_conditioning( + {"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info} + ) + negative_prompt_embeds = self.get_learned_conditioning( + {"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info} + ) return prompt_embeds, negative_prompt_embeds def get_learned_conditioning(self, c): - if hasattr(self.frozen_CLIP_embedder_t3, 'encode') and callable(self.frozen_CLIP_embedder_t3.encode): - if self.embedding_manager is not None and c['text_info'] is not None: - self.embedding_manager.encode_text(c['text_info']) + if hasattr(self.frozen_CLIP_embedder_t3, "encode") and callable(self.frozen_CLIP_embedder_t3.encode): + if self.embedding_manager is not None and c["text_info"] is not None: + self.embedding_manager.encode_text(c["text_info"]) if isinstance(c, dict): - cond_txt = c['c_crossattn'][0] + cond_txt = c["c_crossattn"][0] else: cond_txt = c if self.embedding_manager is not None: @@ -75,7 +73,7 @@ def get_learned_conditioning(self, c): else: cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt) if isinstance(c, dict): - c['c_crossattn'][0] = cond_txt + c["c_crossattn"][0] = cond_txt else: c = cond_txt if isinstance(c, DiagonalGaussianDistribution): From 9dd4ee90c3cede8ce33fdcc3997619296dc46f22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 4 Aug 2024 20:18:45 +0300 Subject: [PATCH 32/87] Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function --- .../anytext/pipeline_anytext.py | 14 +- .../anytext/text_embedding_module.py | 176 +++++++++++++++++- 2 files changed, 178 insertions(+), 12 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index fd903ff666d4..1f392f903794 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1145,10 +1145,6 @@ def __call__( ) guess_mode = guess_mode or global_pool_conditions - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) prompt, texts = self.modify_prompt(prompt) # For classifier free guidance, we need to do two forward passes. @@ -1226,14 +1222,22 @@ def __call__( else: assert False + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) prompt_embeds, negative_prompt_embeds = self.text_embedding_module( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, hint, - negative_prompt, text_info, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, ) # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index e7e5ef98db13..5f5ccb8e4112 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -3,12 +3,24 @@ # text -> tokenizer -> +from typing import List, Optional + import torch from PIL import ImageFont from torch import nn +from diffusers.loaders import ( + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from diffusers.utils import logging +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) from .embedding_manager import EmbeddingManager from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 @@ -50,14 +62,167 @@ def __init__(self, font_path, device, use_fp16): self.embedding_manager.recog = self.cn_recognizer @torch.no_grad() - def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info): + def forward( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + hint, + text_info, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format prompt_embeds = self.get_learned_conditioning( - {"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info} + {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info} ) negative_prompt_embeds = self.get_learned_conditioning( - {"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info} + {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info} ) + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + return prompt_embeds, negative_prompt_embeds def get_learned_conditioning(self, c): @@ -82,6 +247,3 @@ def get_learned_conditioning(self, c): c = self.frozen_CLIP_embedder_t3(c) return c - - def get_unconditional_conditioning(self, N): - return self.get_learned_conditioning({"c_crossattn": [[""] * N], "text_info": None}) From 7dbd4bc41561e8acdbf852b2b800a2c6a0af76ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 5 Aug 2024 14:31:32 +0300 Subject: [PATCH 33/87] Simplify for now --- .../anytext/auxiliary_latent_module.py | 18 +- .../anytext/embedding_manager.py | 4 - .../anytext/pipeline_anytext.py | 11 +- .../anytext/text_embedding_module.py | 188 +----------------- 4 files changed, 15 insertions(+), 206 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 6571d5bb5c9b..e62fe0f9f5a8 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -52,14 +52,11 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): - def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, model_channels=256, **kwargs): + def __init__(self, dims, glyph_channels, position_channels, model_channels, **kwargs): super().__init__() - if font_path is None: - raise ValueError("font_path must be provided!") - self.font = ImageFont.truetype(font_path, 60) + self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") - self.scale_factor = 0.18215 self.glyph_block = nn.Sequential( conv_nd(dims, glyph_channels, 8, 3, padding=1), nn.SiLU(), @@ -98,15 +95,8 @@ def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, nn.SiLU(), ) - self.vae = AutoencoderKL.from_pretrained( - "runwayml/stable-diffusion-v1-5", - subfolder="vae", - torch_dtype=torch.float16 if self.use_fp16 else torch.float32, - variant="fp16" if self.use_fp16 else "fp32", - ) + self.vae = kwargs.get("vae") self.vae.eval() - for param in self.vae.parameters(): - param.requires_grad = False self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) @@ -257,7 +247,7 @@ def forward( return guided_hint, hint, info def encode_first_stage(self, masked_img): - return retrieve_latents(self.vae.encode(masked_img)) * self.scale_factor + return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor def arr2tensor(self, arr, bs): arr = np.transpose(arr, (2, 0, 1)) diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index cbaab5aab682..11dbb4d70c4d 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -156,10 +156,6 @@ def encode_text(self, text_info): if self.emb_type == "ocr": recog_emb = self.get_recog_emb(gline_list) enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) - elif self.emb_type == "vit": - enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0))) - elif self.emb_type == "conv": - enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0))) if self.add_pos: enc_pos = self.position_encoder(torch.cat(gline_list, dim=0)) enc_glyph = enc_glyph + enc_pos diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 1f392f903794..f43952cd6600 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -218,11 +218,10 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, - font_path: str = None, ): super().__init__() - self.text_embedding_module = TextEmbeddingModule(text_encoder, tokenizer) - self.auxiliary_latent_module = AuxiliaryLatentModule(font_path) + self.text_embedding_module = TextEmbeddingModule(use_fp16=unet.dtype == torch.float16) + self.auxiliary_latent_module = AuxiliaryLatentModule(vae=vae, use_fp16=unet.dtype == torch.float16) if safety_checker is None and requires_safety_checker: logger.warning( @@ -1228,16 +1227,10 @@ def __call__( ) prompt_embeds, negative_prompt_embeds = self.text_embedding_module( prompt, - device, - num_images_per_prompt, - self.do_classifier_free_guidance, - hint, text_info, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, ) # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 5f5ccb8e4112..585d9e3f818a 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -31,11 +31,11 @@ class TextEmbeddingModule(nn.Module): - def __init__(self, font_path, device, use_fp16): + def __init__(self, use_fp16): super().__init__() - self.device = device + self.device = "cuda" if torch.cuda.is_available() else "cpu" # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype(font_path, 60) + self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) self.embedding_manager_config = { "valid": True, @@ -49,12 +49,12 @@ def __init__(self, font_path, device, use_fp16): # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "./ocr_weights/ppv3_rec.pth" + rec_model_dir = "./ocr/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "./ocr_recog/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = "./ocr/ppocr_keys_v1.txt" args["use_fp16"] = use_fp16 self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): @@ -65,185 +65,15 @@ def __init__(self, font_path, device, use_fp16): def forward( self, prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - hint, text_info, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, ): - # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format - prompt_embeds = self.get_learned_conditioning( - {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info} - ) - negative_prompt_embeds = self.get_learned_conditioning( - {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info} - ) + self.embedding_manager.encode_text(text_info) + prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + self.embedding_manager.encode_text(text_info) + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode([negative_prompt], embedding_manager=self.embedding_manager) return prompt_embeds, negative_prompt_embeds - - def get_learned_conditioning(self, c): - if hasattr(self.frozen_CLIP_embedder_t3, "encode") and callable(self.frozen_CLIP_embedder_t3.encode): - if self.embedding_manager is not None and c["text_info"] is not None: - self.embedding_manager.encode_text(c["text_info"]) - if isinstance(c, dict): - cond_txt = c["c_crossattn"][0] - else: - cond_txt = c - if self.embedding_manager is not None: - cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt, embedding_manager=self.embedding_manager) - else: - cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt) - if isinstance(c, dict): - c["c_crossattn"][0] = cond_txt - else: - c = cond_txt - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.frozen_CLIP_embedder_t3(c) - - return c From f422423438bc39ea5cbd07aeed1f980bc52584bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 5 Aug 2024 14:32:12 +0300 Subject: [PATCH 34/87] `make style` --- .../anytext/auxiliary_latent_module.py | 1 - .../anytext/text_embedding_module.py | 15 ++++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index e62fe0f9f5a8..5c5737fc0da5 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -10,7 +10,6 @@ from PIL import Image, ImageDraw, ImageFont from torch import nn -from diffusers.models.autoencoders import AutoencoderKL from diffusers.utils import logging diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 585d9e3f818a..77ed5d7328a6 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -3,23 +3,14 @@ # text -> tokenizer -> -from typing import List, Optional +from typing import Optional import torch from PIL import ImageFont from torch import nn -from diffusers.loaders import ( - StableDiffusionLoraLoaderMixin, - TextualInversionLoaderMixin, -) -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.utils import ( - USE_PEFT_BACKEND, logging, - scale_lora_layers, - unscale_lora_layers, ) from .embedding_manager import EmbeddingManager @@ -74,6 +65,8 @@ def forward( prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) self.embedding_manager.encode_text(text_info) - negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode([negative_prompt], embedding_manager=self.embedding_manager) + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( + [negative_prompt], embedding_manager=self.embedding_manager + ) return prompt_embeds, negative_prompt_embeds From 846600953d33b698e7aa0f08961db44fe32fabe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 5 Aug 2024 18:28:22 +0300 Subject: [PATCH 35/87] Up --- .../anytext/auxiliary_latent_module.py | 4 +-- .../anytext/frozen_clip_embedder_t3.py | 10 +++--- .../anytext/pipeline_anytext.py | 36 ++----------------- .../research_projects/anytext/recognizer.py | 8 ++--- .../anytext/text_embedding_module.py | 13 ++++--- 5 files changed, 18 insertions(+), 53 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 5c5737fc0da5..5223bbe98066 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -51,9 +51,9 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): - def __init__(self, dims, glyph_channels, position_channels, model_channels, **kwargs): + def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): super().__init__() - self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") self.glyph_block = nn.Sequential( diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index 7de2b8aed492..3ca9db0502fa 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -1,7 +1,7 @@ import torch from torch import nn from transformers import AutoProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from transformers.models.clip.modeling_clip import _build_causal_attention_mask, _expand_mask +from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask class AbstractEncoder(nn.Module): @@ -108,16 +108,14 @@ def text_encoder_forward( hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager ) - bsz, seq_len = input_shape # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( - hidden_states.device - ) + causal_attention_mask = _create_4d_causal_attention_mask(input_shape, hidden_states.dtype, + device=hidden_states.device) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) last_hidden_state = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index f43952cd6600..b56b8dccb692 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -621,7 +621,7 @@ def prepare_extra_step_kwargs(self, generator, eta): def check_inputs( self, prompt, - image, + # image, callback_steps, negative_prompt=None, prompt_embeds=None, @@ -676,39 +676,6 @@ def check_inputs( is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - transposed_image = [list(t) for t in zip(*image)] - if len(transposed_image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." - ) - for image_ in transposed_image: - self.check_image(image_, prompt, prompt_embeds) - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." - ) - else: - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False # Check `controlnet_conditioning_scale` if ( @@ -717,6 +684,7 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): + print(controlnet_conditioning_scale) raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index a9fa3880906a..33f138f81ae5 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -134,13 +134,13 @@ def get_image_file_list(img_file): class TextRecognizer(object): def __init__(self, args, predictor): - self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] - self.rec_batch_num = args.rec_batch_num + self.rec_image_shape = [int(v) for v in args['rec_image_shape'].split(",")] + self.rec_batch_num = args['rec_batch_num'] self.predictor = predictor - self.chars = self.get_char_dict(args.rec_char_dict_path) + self.chars = self.get_char_dict(args['rec_char_dict_path']) self.char2id = {x: i for i, x in enumerate(self.chars)} self.is_onnx = not isinstance(self.predictor, torch.nn.Module) - self.use_fp16 = args.use_fp16 + self.use_fp16 = args['use_fp16'] # img: CHW def resize_norm_img(self, img, max_wh_ratio): diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 77ed5d7328a6..0efbed25d50b 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -6,17 +6,16 @@ from typing import Optional import torch +from embedding_manager import EmbeddingManager +from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from PIL import ImageFont +from recognizer import TextRecognizer, create_predictor from torch import nn from diffusers.utils import ( logging, ) -from .embedding_manager import EmbeddingManager -from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 -from .recognizer import TextRecognizer, create_predictor - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -26,7 +25,7 @@ def __init__(self, use_fp16): super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) self.embedding_manager_config = { "valid": True, @@ -40,12 +39,12 @@ def __init__(self, use_fp16): # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "./ocr/ppv3_rec.pth" + rec_model_dir = "/home/x/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "./ocr/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = "/home/x/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt" args["use_fp16"] = use_fp16 self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): From 2b4be7a701245d8a3206c79b893fa9e0771ee057 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 5 Aug 2024 18:29:54 +0300 Subject: [PATCH 36/87] feat: Add scripts to convert AnyText controlnet to diffusers --- ...convert_anytext_controlnet_to_diffusers.py | 111 + .../anytext/convert_from_ckpt.py | 1872 +++++++++++++++++ 2 files changed, 1983 insertions(+) create mode 100644 examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py create mode 100644 examples/research_projects/anytext/convert_from_ckpt.py diff --git a/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py b/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py new file mode 100644 index 000000000000..52c3b5281b41 --- /dev/null +++ b/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.""" + +import argparse + +from convert_from_ckpt import download_controlnet_from_original_ckpt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + type=str, + required=False, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--image_size", + default=512, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=False, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + + args = parser.parse_args() + + controlnet = download_controlnet_from_original_ckpt( + checkpoint_path="/home/x/Documents/gits/AnyText/anytext_v1.1.ckpt", + original_config_file="/home/x/Documents/gits/AnyText/models_yaml/anytext_sd15.yaml", + image_size=args.image_size, + extract_ema=args.extract_ema, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device="cpu", + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, + ) + + controlnet.save_pretrained( + "/home/x/Documents/gits/diffusers/examples/research_projects/anytext", safe_serialization=False + ) diff --git a/examples/research_projects/anytext/convert_from_ckpt.py b/examples/research_projects/anytext/convert_from_ckpt.py new file mode 100644 index 000000000000..9968462399ce --- /dev/null +++ b/examples/research_projects/anytext/convert_from_ckpt.py @@ -0,0 +1,1872 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from typing import Dict, Optional, Union + +import requests +import torch +import yaml +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from ...models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from ...utils import is_accelerate_available, logging +from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from ..paint_by_example import PaintByExampleImageEncoder +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if not controlnet: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if ( + key.startswith(unet_key) + and not key.startswith("control_model.glyph_block") + and not key.startswith("control_model.position_block") + and not key.startswith("control_model.fuse_block") + ): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # # conditioning embedding + + # orig_index = 0 + + # new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.weight" + # ) + # new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.bias" + # ) + + # orig_index += 2 + + # diffusers_index = 0 + + # while diffusers_index < 6: + # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.weight" + # ) + # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.bias" + # ) + # diffusers_index += 1 + # orig_index += 2 + + # new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.weight" + # ) + # new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + # f"input_hint_block.{orig_index}.bias" + # ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + try: + config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + else: + text_model = text_encoder + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + try: + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config["model"]["params"]["embedder_config"] + + sd_clip_image_embedder_class = image_embedder_config["target"] + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] + noise_aug_class = noise_aug_config["target"] + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + adapter: Optional[bool] = None, + load_safety_checker: bool = True, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[AutoFeatureExtractor] = None, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + vae=None, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + config_files=None, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`): + Safety checker to use. If this parameter is `None`, the function will load a new instance of + [StableDiffusionSafetyChecker] by itself, if needed. + feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`): + Feature extractor to use. If this parameter is `None`, the function will load a new instance of + [AutoFeatureExtractor] by itself, if needed. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + vae (`AutoencoderKL`, *optional*, defaults to `None`): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If + this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + config_files (`Dict[str, str]`, *optional*, defaults to `None`): + A dictionary mapping from config file names to their contents. If this parameter is `None`, the function + will load the config files by itself, if needed. Valid keys are: + - `v1`: Config file for Stable Diffusion v1 + - `v2`: Config file for Stable Diffusion v2 + - `xl`: Config file for Stable Diffusion XL + - `xl_refiner`: Config file for Stable Diffusion XL Refiner + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if isinstance(checkpoint_path_or_dict, str): + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + else: + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + elif isinstance(checkpoint_path_or_dict, dict): + checkpoint = checkpoint_path_or_dict + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + is_upscale = pipeline_class == StableDiffusionUpscalePipeline + + config_url = None + + # model_type = "v1" + if config_files is not None and "v1" in config_files: + original_config_file = config_files["v1"] + else: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + if config_files is not None and "v2" in config_files: + original_config_file = config_files["v2"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + if config_files is not None and "xl" in config_files: + original_config_file = config_files["xl"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + if config_files is not None and "xl_refiner" in config_files: + original_config_file = config_files["xl_refiner"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + if is_upscale: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" + + if config_url is not None: + original_config_file = BytesIO(requests.get(config_url).content) + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + + original_config = yaml.safe_load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None + ): + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config["model"]["params"]["network_config"] is not None: + if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if pipeline_class is None: + # Check if we have a SDXL or SD model and initialize default pipeline + if model_type not in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline + + if num_in_channels is None and pipeline_class in [ + StableDiffusionInpaintPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLControlNetInpaintPipeline, + ]: + num_in_channels = 9 + if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: + num_in_channels = 7 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config["model"]["params"]: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + elif "network_config" in original_config["model"]["params"]: + original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config["model"]["params"]: + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, path, image_size, upcast_attention, extract_ema + ) + + if "timesteps" in original_config["model"]["params"]: + num_train_timesteps = original_config["model"]["params"]["timesteps"] + else: + num_train_timesteps = 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + if "linear_start" in original_config["model"]["params"]: + beta_start = original_config["model"]["params"]["linear_start"] + else: + beta_start = 0.02 + + if "linear_end" in original_config["model"]["params"]: + beta_end = original_config["model"]["params"]["linear_end"] + else: + beta_end = 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + if pipeline_class == StableDiffusionUpscalePipeline: + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None and vae is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] + ): + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + elif vae is None: + vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + if text_encoder is None: + text_model = convert_open_clip_checkpoint( + checkpoint, config_name, local_files_only=local_files_only, **config_kwargs + ) + else: + text_model = text_encoder + + try: + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." + ) + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + elif pipeline_class == StableDiffusionUpscalePipeline: + scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" + ) + low_res_scheduler = DDPMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + ) + + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained( + karlo_model, subfolder="prior", local_files_only=local_files_only + ) + + try: + prior_tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained( + karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only + ) + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + try: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." + ) + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + try: + tokenizer = ( + CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + if tokenizer is None + else tokenizer + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + is_refiner = model_type == "SDXL-Refiner" + + if (is_refiner is False) and (tokenizer is None): + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if (is_refiner is False) and (text_encoder is None): + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + if tokenizer_2 is None: + try: + tokenizer_2 = CLIPTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." + ) + + if text_encoder_2 is None: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." + + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + + if is_accelerate_available(): # SBM Now move model to cpu. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + elif adapter: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + + else: + pipeline_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "unet": unet, + "scheduler": scheduler, + } + + if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( + pipeline_class == StableDiffusionXLInpaintPipeline + ): + pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) + + if is_refiner: + pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) + + pipe = pipeline_class(**pipeline_kwargs) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, +) -> DiffusionPipeline: + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + with open(original_config_file, "r") as f: + original_config_file = f.read() + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config["model"]["params"]: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet From 1cdbb55a3fe16fa8779f349cdf69e9815963a3a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 5 Aug 2024 18:31:04 +0300 Subject: [PATCH 37/87] `make style` --- .../research_projects/anytext/frozen_clip_embedder_t3.py | 5 +++-- examples/research_projects/anytext/recognizer.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index 3ca9db0502fa..0526964c9c5b 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -110,8 +110,9 @@ def text_encoder_forward( ) # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _create_4d_causal_attention_mask(input_shape, hidden_states.dtype, - device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index 33f138f81ae5..9986562e5586 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -134,13 +134,13 @@ def get_image_file_list(img_file): class TextRecognizer(object): def __init__(self, args, predictor): - self.rec_image_shape = [int(v) for v in args['rec_image_shape'].split(",")] - self.rec_batch_num = args['rec_batch_num'] + self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] + self.rec_batch_num = args["rec_batch_num"] self.predictor = predictor - self.chars = self.get_char_dict(args['rec_char_dict_path']) + self.chars = self.get_char_dict(args["rec_char_dict_path"]) self.char2id = {x: i for i, x in enumerate(self.chars)} self.is_onnx = not isinstance(self.predictor, torch.nn.Module) - self.use_fp16 = args['use_fp16'] + self.use_fp16 = args["use_fp16"] # img: CHW def resize_norm_img(self, img, max_wh_ratio): From da67ff7b59ce668d6a40917c751f55f8e2d94c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 6 Aug 2024 13:06:40 +0300 Subject: [PATCH 38/87] Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule` --- .../anytext/auxiliary_latent_module.py | 143 +---------------- .../anytext/pipeline_anytext.py | 34 ++-- .../anytext/text_embedding_module.py | 146 +++++++++++++++++- 3 files changed, 161 insertions(+), 162 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 5223bbe98066..56e2349aa6a8 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -53,7 +53,7 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): super().__init__() - self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") self.glyph_block = nn.Sequential( @@ -104,146 +104,15 @@ def forward( self, emb, context, - mode, - texts, - prompt, - draw_pos, - ori_image, - max_chars=77, - revise_pos=False, - sort_priority=False, - h=512, - w=512, + text_info, ): - if prompt is None and texts is None: - raise ValueError("Prompt or texts must be provided!") - n_lines = len(texts) - if mode == "generate": - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode == "edit": - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from {ori_image}!") - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - h, w = edit_image.shape[:2] # change h, w by input ref_img - # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) - if draw_pos is None: - pos_imgs = np.zeros((w, h, 1)) - if isinstance(draw_pos, str): - draw_pos = cv2.imread(draw_pos)[..., ::-1] - if draw_pos is None: - raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - pos_imgs = 255 - draw_pos - elif isinstance(draw_pos, torch.Tensor): - pos_imgs = draw_pos.cpu().numpy() - else: - if not isinstance(draw_pos, np.ndarray): - raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - if mode == "edit": - pos_imgs = cv2.resize(pos_imgs, (w, h)) - pos_imgs = pos_imgs[..., 0:1] - pos_imgs = cv2.convertScaleAbs(pos_imgs) - _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) - # separate pos_imgs - pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) - if len(pos_imgs) == 0: - pos_imgs = [np.zeros((h, w, 1))] - if len(pos_imgs) < n_lines: - if n_lines == 1 and texts[0] == " ": - pass # text-to-image without text - else: - raise ValueError( - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" - ) - elif len(pos_imgs) > n_lines: - str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." - logger.warning(str_warning) - # get pre_pos, poly_list, hint that needed for anytext - pre_pos = [] - poly_list = [] - for input_pos in pos_imgs: - if input_pos.mean() != 0: - input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos - poly, pos_img = self.find_polygon(input_pos) - pre_pos += [pos_img / 255.0] - poly_list += [poly] - else: - pre_pos += [np.zeros((h, w, 1))] - poly_list += [None] - np_hint = np.sum(pre_pos, axis=0).clip(0, 1) - # prepare info dict - info = {} - info["glyphs"] = [] - info["gly_line"] = [] - info["positions"] = [] - info["n_lines"] = [len(texts)] * len(prompt) - for i in range(len(texts)): - text = texts[i] - if len(text) > max_chars: - str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' - logger.warning(str_warning) - text = text[:max_chars] - gly_scale = 2 - if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.font, text) - glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - if revise_pos: - resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) - new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos - contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - logger.warning(str_warning) - else: - rect = cv2.minAreaRect(contours[0]) - poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 - else: - glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_line = np.zeros((80, 512, 1)) - pos = pre_pos[i] - info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))] - info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))] - info["positions"] += [self.arr2tensor(pos, len(prompt))] - # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) - masked_img = np.transpose(masked_img, (2, 0, 1)) - masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) - if self.use_fp16: - masked_img = masked_img.half() - masked_x = self.encode_first_stage(masked_img[None, ...]).detach() - if self.use_fp16: - masked_x = masked_x.half() - info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) - hint = self.arr2tensor(np_hint, len(prompt)) - - glyphs = torch.cat(info["glyphs"], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(info["positions"], dim=1).sum(dim=1, keepdim=True) + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) enc_glyph = self.glyph_block(glyphs, emb, context) enc_pos = self.position_block(positions, emb, context) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1)) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) - return guided_hint, hint, info + return guided_hint def encode_first_stage(self, masked_img): return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index b56b8dccb692..4aac7717fc3d 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1114,6 +1114,20 @@ def __call__( prompt, texts = self.modify_prompt(prompt) + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds, text_info = self.text_embedding_module( + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + ori_image, + ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes @@ -1151,15 +1165,10 @@ def __call__( # guess_mode=guess_mode, # ) # height, width = image.shape[-2:] - guided_hint, hint, text_info = self.auxiliary_latent_module( + guided_hint = self.auxiliary_latent_module( emb=timestep_cond, context=prompt_embeds, - mode=mode, - texts=texts, - prompt=prompt, - draw_pos=draw_pos, - ori_image=ori_image, - img_count=len(prompt), + text_info=text_info, ) # elif isinstance(controlnet, MultiControlNetModel): # images = [] @@ -1189,17 +1198,6 @@ def __call__( else: assert False - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.text_embedding_module( - prompt, - text_info, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 0efbed25d50b..ad79849a86d7 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -11,6 +11,9 @@ from PIL import ImageFont from recognizer import TextRecognizer, create_predictor from torch import nn +from torch.nn import functional as F +import numpy as np +import cv2 from diffusers.utils import ( logging, @@ -25,7 +28,7 @@ def __init__(self, use_fp16): super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) self.embedding_manager_config = { "valid": True, @@ -39,12 +42,12 @@ def __init__(self, use_fp16): # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "/home/x/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth" + rec_model_dir = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "/home/x/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt" args["use_fp16"] = use_fp16 self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): @@ -55,11 +58,140 @@ def __init__(self, use_fp16): def forward( self, prompt, - text_info, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + ori_image, + max_chars=77, + revise_pos=False, + sort_priority=False, + h=512, + w=512, ): + if prompt is None and texts is None: + raise ValueError("Prompt or texts must be provided!") + n_lines = len(texts) + if mode == "generate": + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode == "edit": + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + h, w = edit_image.shape[:2] # change h, w by input ref_img + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from {draw_pos}!") + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + if mode == "edit": + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + text_info = {} + text_info["glyphs"] = [] + text_info["gly_line"] = [] + text_info["positions"] = [] + text_info["n_lines"] = [len(texts)] * num_images_per_prompt + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = self.draw_glyph(self.font, text) + glyphs = self.draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + logger.warning(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + pos = pre_pos[i] + text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))] + text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))] + text_info["positions"] += [self.arr2tensor(pos, len(prompt))] + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) + if self.use_fp16: + masked_img = masked_img.half() + masked_x = self.encode_first_stage(masked_img[None, ...]).detach() + if self.use_fp16: + masked_x = masked_x.half() + text_info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) + # hint = self.arr2tensor(np_hint, len(prompt)) + self.embedding_manager.encode_text(text_info) prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) From af30f0f6590dd50b305eb956070001404c3e61a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 6 Aug 2024 13:07:14 +0300 Subject: [PATCH 39/87] make style --- .../research_projects/anytext/text_embedding_module.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index ad79849a86d7..4cac5f99595d 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -3,17 +3,14 @@ # text -> tokenizer -> -from typing import Optional - +import cv2 +import numpy as np import torch from embedding_manager import EmbeddingManager from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from PIL import ImageFont from recognizer import TextRecognizer, create_predictor from torch import nn -from torch.nn import functional as F -import numpy as np -import cv2 from diffusers.utils import ( logging, From a8dbbe2017f6273430fc3c7a510dd2bf11ec14ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 6 Aug 2024 18:54:06 +0300 Subject: [PATCH 40/87] Up --- .../anytext/auxiliary_latent_module.py | 218 ++++++++---------- .../anytext/embedding_manager.py | 1 + .../anytext/frozen_clip_embedder_t3.py | 2 +- .../anytext/pipeline_anytext.py | 11 +- .../anytext/text_embedding_module.py | 177 +++++++++++--- 5 files changed, 251 insertions(+), 158 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 56e2349aa6a8..aabc391e5540 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -2,12 +2,14 @@ # +> fuse layer # position l_p -> position block -> +import math from typing import Optional import cv2 import numpy as np import torch -from PIL import Image, ImageDraw, ImageFont +from einops import repeat +from PIL import ImageFont from torch import nn from diffusers.utils import logging @@ -16,19 +18,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - # Copied from diffusers.models.controlnet.zero_module def zero_module(module: nn.Module) -> nn.Module: for p in module.parameters(): @@ -56,74 +45,142 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") + self.model_channels = model_channels + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) self.glyph_block = nn.Sequential( - conv_nd(dims, glyph_channels, 8, 3, padding=1), + nn.Conv2d(glyph_channels, 8, 3, padding=1), nn.SiLU(), - conv_nd(dims, 8, 8, 3, padding=1), + nn.Conv2d(8, 8, 3, padding=1), nn.SiLU(), - conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 16, 16, 3, padding=1), + nn.Conv2d(16, 16, 3, padding=1), nn.SiLU(), - conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 32, 32, 3, padding=1), + nn.Conv2d(32, 32, 3, padding=1), nn.SiLU(), - conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.Conv2d(32, 96, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 96, 96, 3, padding=1), + nn.Conv2d(96, 96, 3, padding=1), nn.SiLU(), - conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.Conv2d(96, 256, 3, padding=1, stride=2), nn.SiLU(), ) self.position_block = nn.Sequential( - conv_nd(dims, position_channels, 8, 3, padding=1), + nn.Conv2d(position_channels, 8, 3, padding=1), nn.SiLU(), - conv_nd(dims, 8, 8, 3, padding=1), + nn.Conv2d(8, 8, 3, padding=1), nn.SiLU(), - conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 16, 16, 3, padding=1), + nn.Conv2d(16, 16, 3, padding=1), nn.SiLU(), - conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 32, 32, 3, padding=1), + nn.Conv2d(32, 32, 3, padding=1), nn.SiLU(), - conv_nd(dims, 32, 64, 3, padding=1, stride=2), + nn.Conv2d(32, 64, 3, padding=1, stride=2), nn.SiLU(), ) + self.time_embed = self.time_embed.to(device="cuda", dtype=torch.float16) + self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16) + self.position_block = self.position_block.to(device="cuda", dtype=torch.float16) self.vae = kwargs.get("vae") self.vae.eval() - self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)) + self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) + self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16) @torch.no_grad() def forward( self, - emb, context, text_info, + mode, + draw_pos, + ori_image, + num_images_per_prompt, + np_hint, + h=512, + w=512, ): + if mode == "generate": + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode == "edit": + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + h, w = edit_image.shape[:2] # change h, w by input ref_img + + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) + if self.use_fp16: + masked_img = masked_img.half() + masked_x = self.encode_first_stage(masked_img[None, ...]).detach() + if self.use_fp16: + masked_x = masked_x.half() + text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - enc_glyph = self.glyph_block(glyphs, emb, context) - enc_pos = self.position_block(positions, emb, context) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) + t_emb = self.timestep_embedding(torch.tensor([1000], device="cuda"), self.model_channels, repeat_only=False) + if self.use_fp16: + t_emb = t_emb.half() + emb = self.time_embed(t_emb) + print(glyphs.shape, emb.shape, positions.shape, context.shape) + enc_glyph = self.glyph_block(glyphs.cuda(), emb, context) + enc_pos = self.position_block(positions.cuda(), emb, context) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1)) return guided_hint - def encode_first_stage(self, masked_img): - return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor + def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding - def arr2tensor(self, arr, bs): - arr = np.transpose(arr, (2, 0, 1)) - _arr = torch.from_numpy(arr.copy()).float().cpu() - if self.use_fp16: - _arr = _arr.half() - _arr = torch.stack([_arr for _ in range(bs)], dim=0) - return _arr + def encode_first_stage(self, masked_img): + return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor def check_channels(self, image): channels = image.shape[2] if len(image.shape) == 3 else 1 @@ -155,79 +212,6 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = self.insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = self.insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - def to(self, device): self.device = device self.glyph_block = self.glyph_block.to(device) diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index 11dbb4d70c4d..9d011fac3cf1 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -185,6 +185,7 @@ def forward( text_emb = torch.cat(self.text_embs_all[i], dim=0) if sum(idx) != len(text_emb): print("truncation for long caption...") + text_emb = text_emb.to(embedded_text.device) embedded_text[i][idx] = text_emb[: sum(idx)] return embedded_text diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index 0526964c9c5b..e8be9497e876 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -20,7 +20,7 @@ def __init__( ): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version).to(device) if use_vision: self.vit = CLIPVisionModelWithProjection.from_pretrained(version) self.processor = AutoProcessor.from_pretrained(version) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 4aac7717fc3d..1b1b9de010e1 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1118,14 +1118,13 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds, text_info = self.text_embedding_module( + prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module( prompt, texts, negative_prompt, num_images_per_prompt, mode, draw_pos, - ori_image, ) # For classifier free guidance, we need to do two forward passes. @@ -1166,9 +1165,13 @@ def __call__( # ) # height, width = image.shape[-2:] guided_hint = self.auxiliary_latent_module( - emb=timestep_cond, - context=prompt_embeds, + context=prompt_embeds[1], text_info=text_info, + mode=mode, + draw_pos=draw_pos, + ori_image=ori_image, + num_images_per_prompt=num_images_per_prompt, + np_hint=np_hint, ) # elif isinstance(controlnet, MultiControlNetModel): # images = [] diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 4cac5f99595d..53bb87b901d4 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -8,7 +8,7 @@ import torch from embedding_manager import EmbeddingManager from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 -from PIL import ImageFont +from PIL import Image, ImageDraw, ImageFont from recognizer import TextRecognizer, create_predictor from torch import nn @@ -23,6 +23,7 @@ class TextEmbeddingModule(nn.Module): def __init__(self, use_fp16): super().__init__() + self.use_fp16 = use_fp16 self.device = "cuda" if torch.cuda.is_available() else "cpu" # TODO: Learn if the recommended font file is free to use self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) @@ -45,7 +46,7 @@ def __init__(self, use_fp16): args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt" - args["use_fp16"] = use_fp16 + args["use_fp16"] = False self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): param.requires_grad = False @@ -60,36 +61,14 @@ def forward( num_images_per_prompt, mode, draw_pos, - ori_image, + sort_priority="↕", max_chars=77, revise_pos=False, - sort_priority=False, h=512, w=512, ): if prompt is None and texts is None: raise ValueError("Prompt or texts must be provided!") - n_lines = len(texts) - if mode == "generate": - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode == "edit": - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from {ori_image}!") - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - h, w = edit_image.shape[:2] # change h, w by input ref_img # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) if draw_pos is None: pos_imgs = np.zeros((w, h, 1)) @@ -112,6 +91,7 @@ def forward( pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) if len(pos_imgs) == 0: pos_imgs = [np.zeros((h, w, 1))] + n_lines = len(texts) if len(pos_imgs) < n_lines: if n_lines == 1 and texts[0] == " ": pass # text-to-image without text @@ -177,16 +157,7 @@ def forward( text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))] text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))] text_info["positions"] += [self.arr2tensor(pos, len(prompt))] - # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) - masked_img = np.transpose(masked_img, (2, 0, 1)) - masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) - if self.use_fp16: - masked_img = masked_img.half() - masked_x = self.encode_first_stage(masked_img[None, ...]).detach() - if self.use_fp16: - masked_x = masked_x.half() - text_info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0) + # hint = self.arr2tensor(np_hint, len(prompt)) self.embedding_manager.encode_text(text_info) @@ -197,4 +168,138 @@ def forward( [negative_prompt], embedding_manager=self.embedding_manager ) - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, negative_prompt_embeds, text_info, np_hint + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().cpu() + if self.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr + + def separate_pos_imgs(self, img, sort_priority, gap=102): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) + components = [] + for label in range(1, num_labels): + component = np.zeros_like(img) + component[labels == label] = 255 + components.append((component, centroids[label])) + if sort_priority == "↕": + fir, sec = 1, 0 # top-down first + elif sort_priority == "↔": + fir, sec = 0, 1 # left-right first + else: + raise ValueError(f"Unknown sort_priority: {sort_priority}") + components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) + sorted_components = [c[0] for c in components] + return sorted_components + + def find_polygon(self, image, min_rect=False): + contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + max_contour = max(contours, key=cv2.contourArea) # get contour with max area + if min_rect: + # get minimum enclosing rectangle + rect = cv2.minAreaRect(max_contour) + poly = np.int0(cv2.boxPoints(rect)) + else: + # get approximate polygon + epsilon = 0.01 * cv2.arcLength(max_contour, True) + poly = cv2.approxPolyDP(max_contour, epsilon, True) + n, _, xy = poly.shape + poly = poly.reshape(n, xy) + cv2.drawContours(image, [poly], -1, 255, -1) + return poly, image + + def draw_glyph(self, font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img From 936c2ff62bb6045326fff4d2795188b1dffb9b04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 6 Aug 2024 19:12:14 +0300 Subject: [PATCH 41/87] Simplify --- .../anytext/auxiliary_latent_module.py | 37 +------------------ .../anytext/text_controlnet.py | 2 +- .../anytext/text_embedding_module.py | 5 --- 3 files changed, 3 insertions(+), 41 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index aabc391e5540..316f185b702a 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -1,14 +1,8 @@ -# text -> glyph render -> glyph l_g -> glyph block -> -# +> fuse layer -# position l_p -> position block -> - -import math from typing import Optional import cv2 import numpy as np import torch -from einops import repeat from PIL import ImageFont from torch import nn @@ -146,39 +140,12 @@ def forward( glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - t_emb = self.timestep_embedding(torch.tensor([1000], device="cuda"), self.model_channels, repeat_only=False) - if self.use_fp16: - t_emb = t_emb.half() - emb = self.time_embed(t_emb) - print(glyphs.shape, emb.shape, positions.shape, context.shape) - enc_glyph = self.glyph_block(glyphs.cuda(), emb, context) - enc_pos = self.position_block(positions.cuda(), emb, context) + enc_glyph = self.glyph_block(glyphs.cuda()) + enc_pos = self.position_block(positions.cuda()) guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1)) return guided_hint - def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - def encode_first_stage(self, masked_img): return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index c48c7081924c..cc6f2c59835d 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -28,7 +28,7 @@ class AnyTextControlNetModel(ControlNetModel): """ - A PromptDiffusionControlNet model. + A AnyTextControlNetModel model. Args: in_channels (`int`, defaults to 4): diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 53bb87b901d4..1f09b63b0eaa 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -1,8 +1,3 @@ -# text -> glyph render -> glyph lines -> OCR -> linear -> -# +> Token Replacement -> FrozenCLIPEmbedderT3 -# text -> tokenizer -> - - import cv2 import numpy as np import torch From cffa03617f134e39090271888824ed44b8fc08a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 7 Aug 2024 15:23:52 +0300 Subject: [PATCH 42/87] Up --- .../anytext/auxiliary_latent_module.py | 18 +++++++----------- .../anytext/pipeline_anytext.py | 1 - .../anytext/text_embedding_module.py | 6 +++--- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 316f185b702a..f686c39769dd 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -4,6 +4,7 @@ import numpy as np import torch from PIL import ImageFont +from safetensors.torch import load_file from torch import nn from diffusers.utils import logging @@ -34,18 +35,12 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): - def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): + def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): super().__init__() - self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) self.device = kwargs.get("device", "cpu") - self.model_channels = model_channels - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - nn.Linear(model_channels, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) + self.glyph_block = nn.Sequential( nn.Conv2d(glyph_channels, 8, 3, padding=1), nn.SiLU(), @@ -83,7 +78,8 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels nn.Conv2d(32, 64, 3, padding=1, stride=2), nn.SiLU(), ) - self.time_embed = self.time_embed.to(device="cuda", dtype=torch.float16) + self.glyph_block.load_state_dict(load_file("glyph_block.safetensors")) + self.position_block.load_state_dict(load_file("position_block.safetensors")) self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16) self.position_block = self.position_block.to(device="cuda", dtype=torch.float16) @@ -91,12 +87,12 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels self.vae.eval() self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) + self.fuse_block.load_state_dict(load_file("fuse_block.safetensors")) self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16) @torch.no_grad() def forward( self, - context, text_info, mode, draw_pos, diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 1b1b9de010e1..408b7450e8e0 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1165,7 +1165,6 @@ def __call__( # ) # height, width = image.shape[-2:] guided_hint = self.auxiliary_latent_module( - context=prompt_embeds[1], text_info=text_info, mode=mode, draw_pos=draw_pos, diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 1f09b63b0eaa..bcc687e8da44 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -21,7 +21,7 @@ def __init__(self, use_fp16): self.use_fp16 = use_fp16 self.device = "cuda" if torch.cuda.is_available() else "cpu" # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) self.embedding_manager_config = { "valid": True, @@ -35,12 +35,12 @@ def __init__(self, use_fp16): # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth" + rec_model_dir = "ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = "ppocr_keys_v1.txt" args["use_fp16"] = False self.cn_recognizer = TextRecognizer(args, self.text_predictor) for param in self.text_predictor.parameters(): From 8b43bc34613a16b7288ebf860327088d77be66c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 7 Aug 2024 15:30:40 +0300 Subject: [PATCH 43/87] feat: Add safetensors module for loading model file --- examples/research_projects/anytext/recognizer.py | 3 ++- examples/research_projects/anytext/text_embedding_module.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index 9986562e5586..842ee7842b19 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from easydict import EasyDict as edict from ocr_recog.RecModel import RecModel +from safetensors.torch import load_file from skimage.transform._geometric import _umeyama as get_sym_mat @@ -105,7 +106,7 @@ def create_predictor(model_dir=None, model_lang="ch", is_onnx=False): rec_model = RecModel(rec_config) if model_file_path is not None: - rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu")) + rec_model.load_state_dict(load_file(model_file_path)) rec_model.eval() return rec_model.eval() diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index bcc687e8da44..43d1d4091b00 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -35,7 +35,7 @@ def __init__(self, use_fp16): # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "ppv3_rec.pth" + rec_model_dir = "ppv3_rec.safetensors" self.text_predictor = create_predictor(rec_model_dir).eval() args = {} args["rec_image_shape"] = "3, 48, 320" From f60a72bfca670e6f74c9037eb087a56bbcbf2dd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 7 Aug 2024 18:41:02 +0300 Subject: [PATCH 44/87] Fix device issues --- .../anytext/auxiliary_latent_module.py | 2 +- .../research_projects/anytext/embedding_manager.py | 1 + .../anytext/frozen_clip_embedder_t3.py | 2 +- .../research_projects/anytext/pipeline_anytext.py | 1 + .../anytext/text_embedding_module.py | 13 ++++++++----- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index f686c39769dd..cd3f8abeb6ad 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -39,7 +39,7 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, ** super().__init__() self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) - self.device = kwargs.get("device", "cpu") + self.device = kwargs.get("device", "cuda") self.glyph_block = nn.Sequential( nn.Conv2d(glyph_channels, 8, 3, padding=1), diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index 9d011fac3cf1..0357a1e29971 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -134,6 +134,7 @@ def __init__( self.position_encoder = EncodeNet(position_channels, token_dim) if emb_type == "ocr": self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim)) + self.proj = self.proj.to(dtype=torch.float16 if kwargs.get("use_fp16", False) else torch.float32) if emb_type == "conv": self.glyph_encoder = EncodeNet(glyph_channels, token_dim) diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index e8be9497e876..296e4432cca2 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -20,7 +20,7 @@ def __init__( ): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version).to(device) + self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=torch.float16).to(device) if use_vision: self.vit = CLIPVisionModelWithProjection.from_pretrained(version) self.processor = AutoProcessor.from_pretrained(version) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 408b7450e8e0..747255b9b64e 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1172,6 +1172,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, np_hint=np_hint, ) + height, width = 512, 512 # elif isinstance(controlnet, MultiControlNetModel): # images = [] diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 43d1d4091b00..1f76b7d38fcd 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -30,6 +30,7 @@ def __init__(self, use_fp16): "position_channels": 1, "add_pos": False, "placeholder_string": "*", + "use_fp16": self.use_fp16, } self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, **self.embedding_manager_config) # TODO: Understand the reason of param.requires_grad = True @@ -41,8 +42,10 @@ def __init__(self, use_fp16): args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 args["rec_char_dict_path"] = "ppocr_keys_v1.txt" - args["use_fp16"] = False - self.cn_recognizer = TextRecognizer(args, self.text_predictor) + args["use_fp16"] = True + self.cn_recognizer = TextRecognizer( + args, self.text_predictor.to(dtype=torch.float16 if use_fp16 else torch.float32) + ) for param in self.text_predictor.parameters(): param.requires_grad = False self.embedding_manager.recog = self.cn_recognizer @@ -149,9 +152,9 @@ def forward( glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) gly_line = np.zeros((80, 512, 1)) pos = pre_pos[i] - text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))] - text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))] - text_info["positions"] += [self.arr2tensor(pos, len(prompt))] + text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] + text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] + text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] # hint = self.arr2tensor(np_hint, len(prompt)) From be4a319bbc23a03dfa98be19a9318f0cd7d981aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 8 Aug 2024 20:40:50 +0300 Subject: [PATCH 45/87] Up --- .../anytext/auxiliary_latent_module.py | 33 +++++++------- .../anytext/convert_from_ckpt.py | 16 +++---- .../anytext/embedding_manager.py | 22 +++------- .../anytext/frozen_clip_embedder_t3.py | 16 ++++--- .../anytext/pipeline_anytext.py | 9 +++- .../research_projects/anytext/recognizer.py | 43 ++++++++----------- .../anytext/text_embedding_module.py | 31 +++++-------- 7 files changed, 76 insertions(+), 94 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index cd3f8abeb6ad..3898e7e9b090 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -37,9 +37,9 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): super().__init__() - self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60) self.use_fp16 = kwargs.get("use_fp16", False) - self.device = kwargs.get("device", "cuda") + self.device = kwargs.get("device", "cpu") self.glyph_block = nn.Sequential( nn.Conv2d(glyph_channels, 8, 3, padding=1), @@ -78,17 +78,22 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, ** nn.Conv2d(32, 64, 3, padding=1, stride=2), nn.SiLU(), ) - self.glyph_block.load_state_dict(load_file("glyph_block.safetensors")) - self.position_block.load_state_dict(load_file("position_block.safetensors")) - self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16) - self.position_block = self.position_block.to(device="cuda", dtype=torch.float16) self.vae = kwargs.get("vae") self.vae.eval() self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) - self.fuse_block.load_state_dict(load_file("fuse_block.safetensors")) - self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16) + + self.glyph_block.load_state_dict( + load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device) + ) + self.glyph_block = self.glyph_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) + self.position_block.load_state_dict( + load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device) + ) + self.position_block = self.position_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) + self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device)) + self.fuse_block = self.fuse_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) @torch.no_grad() def forward( @@ -121,7 +126,6 @@ def forward( edit_image = self.resize_image( edit_image, max_length=768 ) # make w h multiple of 64, resize if w or h > max_length - h, w = edit_image.shape[:2] # change h, w by input ref_img # get masked_x masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) @@ -129,22 +133,19 @@ def forward( masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) if self.use_fp16: masked_img = masked_img.half() - masked_x = self.encode_first_stage(masked_img[None, ...]).detach() + masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() if self.use_fp16: masked_x = masked_x.half() text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - enc_glyph = self.glyph_block(glyphs.cuda()) - enc_pos = self.position_block(positions.cuda()) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1)) + enc_glyph = self.glyph_block(glyphs) + enc_pos = self.position_block(positions) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) return guided_hint - def encode_first_stage(self, masked_img): - return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor - def check_channels(self, image): channels = image.shape[2] if len(image.shape) == 3 else 1 if channels == 1: diff --git a/examples/research_projects/anytext/convert_from_ckpt.py b/examples/research_projects/anytext/convert_from_ckpt.py index 9968462399ce..172afb30a4f5 100644 --- a/examples/research_projects/anytext/convert_from_ckpt.py +++ b/examples/research_projects/anytext/convert_from_ckpt.py @@ -34,13 +34,18 @@ CLIPVisionModelWithProjection, ) -from ...models import ( +from diffusers.models import ( AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel, ) -from ...schedulers import ( +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, @@ -51,12 +56,7 @@ PNDMScheduler, UnCLIPScheduler, ) -from ...utils import is_accelerate_available, logging -from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from ..paint_by_example import PaintByExampleImageEncoder -from ..pipeline_utils import DiffusionPipeline -from .safety_checker import StableDiffusionSafetyChecker -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.utils import is_accelerate_available, logging if is_accelerate_available(): diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index 0357a1e29971..7ebdb389cdac 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -107,25 +107,17 @@ class EmbeddingManager(nn.Module): def __init__( self, embedder, - valid=True, - glyph_channels=20, position_channels=1, placeholder_string="*", add_pos=False, emb_type="ocr", - **kwargs, + use_fp16=False, ): super().__init__() - if hasattr(embedder, "tokenizer"): # using Stable Diffusion's CLIP encoder - get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - token_dim = 768 - if hasattr(embedder, "vit"): - assert emb_type == "vit" - self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor) - self.get_recog_emb = None - else: # using LDM's BERT encoder - get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) - token_dim = 1280 + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + token_dim = 768 + self.get_recog_emb = None + token_dim = 1280 self.token_dim = token_dim self.emb_type = emb_type @@ -134,9 +126,7 @@ def __init__( self.position_encoder = EncodeNet(position_channels, token_dim) if emb_type == "ocr": self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim)) - self.proj = self.proj.to(dtype=torch.float16 if kwargs.get("use_fp16", False) else torch.float32) - if emb_type == "conv": - self.glyph_encoder = EncodeNet(glyph_channels, token_dim) + self.proj = self.proj.to(dtype=torch.float16 if use_fp16 else torch.float32) self.placeholder_token = get_token_for_string(placeholder_string) diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index 296e4432cca2..f0446ca52621 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -1,6 +1,6 @@ import torch from torch import nn -from transformers import AutoProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPTextModel, CLIPTokenizer from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask @@ -16,14 +16,18 @@ class FrozenCLIPEmbedderT3(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" def __init__( - self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False + self, + version="openai/clip-vit-large-patch14", + device="cpu", + max_length=77, + freeze=True, + use_fp16=False, ): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=torch.float16).to(device) - if use_vision: - self.vit = CLIPVisionModelWithProjection.from_pretrained(version) - self.processor = AutoProcessor.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained( + version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32 + ).to(device) self.device = device self.max_length = max_length if freeze: diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 747255b9b64e..32cb07e1a425 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -218,10 +218,15 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, + font_path: str = "font/Arial_Unicode.ttf", ): super().__init__() - self.text_embedding_module = TextEmbeddingModule(use_fp16=unet.dtype == torch.float16) - self.auxiliary_latent_module = AuxiliaryLatentModule(vae=vae, use_fp16=unet.dtype == torch.float16) + self.text_embedding_module = TextEmbeddingModule( + use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + ) + self.auxiliary_latent_module = AuxiliaryLatentModule( + vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + ) if safety_checker is None and requires_safety_checker: logger.warning( diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index 842ee7842b19..7359cee0b43d 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -78,37 +78,30 @@ def crop_image(src_img, mask): return result -def create_predictor(model_dir=None, model_lang="ch", is_onnx=False): +def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): model_file_path = model_dir if model_file_path is not None and not os.path.exists(model_file_path): raise ValueError("not find model file path {}".format(model_file_path)) - if is_onnx: - import onnxruntime as ort - - sess = ort.InferenceSession( - model_file_path, providers=["CPUExecutionProvider"] - ) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' - return sess + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 else: - if model_lang == "ch": - n_class = 6625 - elif model_lang == "en": - n_class = 97 - else: - raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = edict( - in_channels=3, - backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), - neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), - head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = edict( + in_channels=3, + backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), + neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), + head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + ) + + rec_model = RecModel(rec_config) + if model_file_path is not None: + rec_model.load_state_dict(load_file(model_file_path, device=device)).to( + dtype=torch.float16 if use_fp16 else torch.float32 ) - - rec_model = RecModel(rec_config) - if model_file_path is not None: - rec_model.load_state_dict(load_file(model_file_path)) - rec_model.eval() - return rec_model.eval() + return rec_model def _check_image_file(path): diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 1f76b7d38fcd..607317891f24 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -16,36 +16,25 @@ class TextEmbeddingModule(nn.Module): - def __init__(self, use_fp16): + def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() self.use_fp16 = use_fp16 - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device) - self.embedding_manager_config = { - "valid": True, - "emb_type": "ocr", - "glyph_channels": 1, - "position_channels": 1, - "add_pos": False, - "placeholder_string": "*", - "use_fp16": self.use_fp16, - } - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, **self.embedding_manager_config) + self.font = ImageFont.truetype(font_path, 60) + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) # TODO: Understand the reason of param.requires_grad = True for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - rec_model_dir = "ppv3_rec.safetensors" - self.text_predictor = create_predictor(rec_model_dir).eval() + rec_model_dir = "OCR/ppv3_rec.safetensors" + self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16) args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "ppocr_keys_v1.txt" - args["use_fp16"] = True - self.cn_recognizer = TextRecognizer( - args, self.text_predictor.to(dtype=torch.float16 if use_fp16 else torch.float32) - ) + args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" + args["use_fp16"] = self.use_fp16 + self.cn_recognizer = TextRecognizer(args, self.text_predictor, device=self.device, use_fp16=self.use_fp16) for param in self.text_predictor.parameters(): param.requires_grad = False self.embedding_manager.recog = self.cn_recognizer From f7131713b656aae407a84d50413a80e2d9000e9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 8 Aug 2024 23:11:12 +0300 Subject: [PATCH 46/87] Up --- .../anytext/auxiliary_latent_module.py | 11 ++++++----- .../anytext/text_embedding_module.py | 8 ++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 3898e7e9b090..f09066bf29e0 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -35,11 +35,13 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): - def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): + def __init__( + self, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False + ): super().__init__() self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60) - self.use_fp16 = kwargs.get("use_fp16", False) - self.device = kwargs.get("device", "cpu") + self.use_fp16 = use_fp16 + self.device = device self.glyph_block = nn.Sequential( nn.Conv2d(glyph_channels, 8, 3, padding=1), @@ -79,8 +81,7 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, ** nn.SiLU(), ) - self.vae = kwargs.get("vae") - self.vae.eval() + self.vae = vae.eval() self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 607317891f24..32a7a8700469 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -28,15 +28,15 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True rec_model_dir = "OCR/ppv3_rec.safetensors" - self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16) + self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() + for param in self.text_predictor.parameters(): + param.requires_grad = False args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" args["use_fp16"] = self.use_fp16 - self.cn_recognizer = TextRecognizer(args, self.text_predictor, device=self.device, use_fp16=self.use_fp16) - for param in self.text_predictor.parameters(): - param.requires_grad = False + self.cn_recognizer = TextRecognizer(args, self.text_predictor) self.embedding_manager.recog = self.cn_recognizer @torch.no_grad() From f347ff2e940d3527d87a343f428b5ac4b134e2c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 Aug 2024 14:26:08 +0300 Subject: [PATCH 47/87] refactor: Simplify --- .../anytext/auxiliary_latent_module.py | 19 ++-- .../anytext/embedding_manager.py | 103 ++---------------- 2 files changed, 17 insertions(+), 105 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index f09066bf29e0..a1ff0ae14a8f 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -13,13 +13,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Copied from diffusers.models.controlnet.zero_module -def zero_module(module: nn.Module) -> nn.Module: - for p in module.parameters(): - nn.init.zeros_(p) - return module - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -83,18 +76,20 @@ def __init__( self.vae = vae.eval() - self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) + self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) self.glyph_block.load_state_dict( load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device) ) - self.glyph_block = self.glyph_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) self.position_block.load_state_dict( load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device) ) - self.position_block = self.position_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device)) - self.fuse_block = self.fuse_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32) + + if use_fp16: + self.glyph_block = self.glyph_block.to(dtype=torch.float16) + self.position_block = self.position_block.to(dtype=torch.float16) + self.fuse_block = self.fuse_block.to(dtype=torch.float16) @torch.no_grad() def forward( @@ -181,6 +176,6 @@ def to(self, device): self.device = device self.glyph_block = self.glyph_block.to(device) self.position_block = self.position_block.to(device) - self.vae = self.vae.to(device) self.fuse_block = self.fuse_block.to(device) + self.vae = self.vae.to(device) return self diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index 7ebdb389cdac..f396d69f4ae8 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -5,27 +5,7 @@ import torch import torch.nn as nn -import torch.nn.functional as F - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -# Copied from diffusers.models.controlnet.zero_module -def zero_module(module: nn.Module) -> nn.Module: - for p in module.parameters(): - nn.init.zeros_(p) - return module +from safetensors.torch import load_file def get_clip_token_for_string(tokenizer, string): @@ -45,24 +25,6 @@ def get_clip_token_for_string(tokenizer, string): return tokens[0, 1] -def get_bert_token_for_string(tokenizer, string): - token = tokenizer(string) - assert ( - torch.count_nonzero(token) == 3 - ), f"String '{string}' maps to more than a single token. Please use another string" - token = token[0, 1] - return token - - -def get_clip_vision_emb(encoder, processor, img): - _img = img.repeat(1, 3, 1, 1) * 255 - inputs = processor(images=_img, return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(img.device) - outputs = encoder(**inputs) - emb = outputs.image_embeds - return emb - - def get_recog_emb(encoder, img_list): _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] encoder.predictor.eval() @@ -70,86 +32,40 @@ def get_recog_emb(encoder, img_list): return preds_neck -def pad_H(x): - _, _, H, W = x.shape - p_top = (W - H) // 2 - p_bot = W - H - p_top - return F.pad(x, (0, 0, p_top, p_bot)) - - -class EncodeNet(nn.Module): - def __init__(self, in_channels, out_channels): - super(EncodeNet, self).__init__() - chan = 16 - n_layer = 4 # downsample - - self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1) - self.conv_list = nn.ModuleList([]) - _c = chan - for i in range(n_layer): - self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2)) - _c *= 2 - self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1) - self.avgpool = nn.AdaptiveAvgPool2d(1) - self.act = nn.SiLU() - - def forward(self, x): - x = self.act(self.conv1(x)) - for layer in self.conv_list: - x = self.act(layer(x)) - x = self.act(self.conv2(x)) - x = self.avgpool(x) - x = x.view(x.size(0), -1) - return x - - class EmbeddingManager(nn.Module): def __init__( self, embedder, - position_channels=1, placeholder_string="*", - add_pos=False, - emb_type="ocr", use_fp16=False, ): super().__init__() get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) token_dim = 768 self.get_recog_emb = None - token_dim = 1280 self.token_dim = token_dim - self.emb_type = emb_type - self.add_pos = add_pos - if add_pos: - self.position_encoder = EncodeNet(position_channels, token_dim) - if emb_type == "ocr": - self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim)) - self.proj = self.proj.to(dtype=torch.float16 if use_fp16 else torch.float32) + self.proj = nn.Linear(40 * 64, token_dim) + self.proj.load_state_dict(load_file("EmbeddingManager/embedding_manager.safetensors", device=self.device)) + if use_fp16: + self.proj = self.proj.to(dtype=torch.float16) self.placeholder_token = get_token_for_string(placeholder_string) + @torch.no_grad() def encode_text(self, text_info): - if self.get_recog_emb is None and self.emb_type == "ocr": + if self.get_recog_emb is None: self.get_recog_emb = partial(get_recog_emb, self.recog) gline_list = [] - pos_list = [] for i in range(len(text_info["n_lines"])): # sample index in a batch n_lines = text_info["n_lines"][i] for j in range(n_lines): # line gline_list += [text_info["gly_line"][j][i : i + 1]] - if self.add_pos: - pos_list += [text_info["positions"][j][i : i + 1]] if len(gline_list) > 0: - if self.emb_type == "ocr": - recog_emb = self.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) - if self.add_pos: - enc_pos = self.position_encoder(torch.cat(gline_list, dim=0)) - enc_glyph = enc_glyph + enc_pos + recog_emb = self.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) self.text_embs_all = [] n_idx = 0 @@ -161,6 +77,7 @@ def encode_text(self, text_info): n_idx += 1 self.text_embs_all += [text_embs] + @torch.no_grad() def forward( self, tokenized_text, From d52e973756e31551f4cf368501e1dc720a5a1584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 Aug 2024 17:32:37 +0300 Subject: [PATCH 48/87] refactor: Simplify code for loading models and handling data types --- .../anytext/auxiliary_latent_module.py | 10 ++++---- .../anytext/embedding_manager.py | 4 +-- .../anytext/ocr_recog/RecModel.py | 2 ++ .../anytext/pipeline_anytext.py | 6 +++-- .../research_projects/anytext/recognizer.py | 4 +-- .../anytext/text_embedding_module.py | 25 ++++++++++++++----- 6 files changed, 33 insertions(+), 18 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index a1ff0ae14a8f..25a59e99e5f3 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -29,10 +29,10 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): def __init__( - self, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False + self, font_path, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False ): super().__init__() - self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60) + self.font = ImageFont.truetype(font_path, 60) self.use_fp16 = use_fp16 self.device = device @@ -79,12 +79,12 @@ def __init__( self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) self.glyph_block.load_state_dict( - load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device) + load_file("glyph_block.safetensors", device=str(self.device)) ) self.position_block.load_state_dict( - load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device) + load_file("position_block.safetensors", device=str(self.device)) ) - self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device)) + self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) if use_fp16: self.glyph_block = self.glyph_block.to(dtype=torch.float16) diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index f396d69f4ae8..256b0a60857c 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -46,7 +46,7 @@ def __init__( self.token_dim = token_dim self.proj = nn.Linear(40 * 64, token_dim) - self.proj.load_state_dict(load_file("EmbeddingManager/embedding_manager.safetensors", device=self.device)) + self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) if use_fp16: self.proj = self.proj.to(dtype=torch.float16) @@ -65,7 +65,7 @@ def encode_text(self, text_info): if len(gline_list) > 0: recog_emb = self.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.device)) self.text_embs_all = [] n_idx = 0 diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py index 50b0cec967d5..4c72ecdec3c4 100755 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -34,6 +34,8 @@ def load_3rd_state_dict(self, _3rd_name, _state): self.head.load_3rd_state_dict(_3rd_name, _state) def forward(self, x): + import torch + x = x.to(torch.float32) x = self.backbone(x) x = self.neck(x) x = self.head(x) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 32cb07e1a425..47f03f7e41f9 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -208,6 +208,7 @@ class AnyTextPipeline( def __init__( self, + font_path: str, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, @@ -218,7 +219,6 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, - font_path: str = "font/Arial_Unicode.ttf", ): super().__init__() self.text_embedding_module = TextEmbeddingModule( @@ -257,13 +257,15 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, + # text_embedding_module=text_embedding_module, + # auxiliary_latent_module=auxiliary_latent_module, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.register_to_config(requires_safety_checker=requires_safety_checker) + self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path) def modify_prompt(self, prompt): prompt = prompt.replace("“", '"') diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index 7359cee0b43d..25d527fdedc3 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -98,9 +98,7 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal rec_model = RecModel(rec_config) if model_file_path is not None: - rec_model.load_state_dict(load_file(model_file_path, device=device)).to( - dtype=torch.float16 if use_fp16 else torch.float32 - ) + rec_model.load_state_dict(torch.load(model_file_path, map_location=device)) return rec_model diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 32a7a8700469..6b62c8d37b27 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -24,10 +24,9 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): self.font = ImageFont.truetype(font_path, 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) - # TODO: Understand the reason of param.requires_grad = True - for param in self.embedding_manager.embedding_parameters(): - param.requires_grad = True - rec_model_dir = "OCR/ppv3_rec.safetensors" + # for param in self.embedding_manager.embedding_parameters(): + # param.requires_grad = True + rec_model_dir = "OCR/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() for param in self.text_predictor.parameters(): param.requires_grad = False @@ -36,8 +35,7 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): args["rec_batch_num"] = 6 args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" args["use_fp16"] = self.use_fp16 - self.cn_recognizer = TextRecognizer(args, self.text_predictor) - self.embedding_manager.recog = self.cn_recognizer + self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) @torch.no_grad() def forward( @@ -290,3 +288,18 @@ def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, heigh img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def to(self, device): + self.device = device + self.frozen_CLIP_embedder_t3.to(device) + self.embedding_manager.to(device) + self.text_predictor.to(device) + return self From a3b493f8d79aac5a295cb34be6ee8b34ff151d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 Aug 2024 17:33:15 +0300 Subject: [PATCH 49/87] `make style` --- .../anytext/auxiliary_latent_module.py | 17 ++++++++++------- .../anytext/ocr_recog/RecModel.py | 1 + .../research_projects/anytext/recognizer.py | 1 - 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index 25a59e99e5f3..fa83db89b206 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -29,7 +29,14 @@ def retrieve_latents( class AuxiliaryLatentModule(nn.Module): def __init__( - self, font_path, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False + self, + font_path, + glyph_channels=1, + position_channels=1, + model_channels=320, + vae=None, + device="cpu", + use_fp16=False, ): super().__init__() self.font = ImageFont.truetype(font_path, 60) @@ -78,12 +85,8 @@ def __init__( self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) - self.glyph_block.load_state_dict( - load_file("glyph_block.safetensors", device=str(self.device)) - ) - self.position_block.load_state_dict( - load_file("position_block.safetensors", device=str(self.device)) - ) + self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) + self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) if use_fp16: diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py index 4c72ecdec3c4..5f0f8f0375f1 100755 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -35,6 +35,7 @@ def load_3rd_state_dict(self, _3rd_name, _state): def forward(self, x): import torch + x = x.to(torch.float32) x = self.backbone(x) x = self.neck(x) diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py index 25d527fdedc3..5cd7e245509d 100755 --- a/examples/research_projects/anytext/recognizer.py +++ b/examples/research_projects/anytext/recognizer.py @@ -13,7 +13,6 @@ import torch.nn.functional as F from easydict import EasyDict as edict from ocr_recog.RecModel import RecModel -from safetensors.torch import load_file from skimage.transform._geometric import _umeyama as get_sym_mat From 4267c84808e4641244fb60ee2bc2eedea444b9c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 Aug 2024 22:54:52 +0300 Subject: [PATCH 50/87] refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule --- .../anytext/frozen_clip_embedder_t3.py | 5 +++++ .../research_projects/anytext/text_embedding_module.py | 10 +++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py index f0446ca52621..00f33109b3d0 100644 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ b/examples/research_projects/anytext/frozen_clip_embedder_t3.py @@ -207,3 +207,8 @@ def split_chunks(self, input_ids, chunk_size=75): remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) tokens_list.append(remaining_group_pad) return tokens_list + + def to(self, *args, **kwargs): + self.transformer = self.transformer.to(*args, **kwargs) + self.device = self.transformer.device + return self diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index 6b62c8d37b27..f992cdd6024d 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -297,9 +297,9 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - def to(self, device): - self.device = device - self.frozen_CLIP_embedder_t3.to(device) - self.embedding_manager.to(device) - self.text_predictor.to(device) + def to(self, *args, **kwargs): + self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) + self.embedding_manager = self.embedding_manager.to(*args, **kwargs) + self.text_predictor = self.text_predictor.to(*args, **kwargs) + self.device = self.frozen_CLIP_embedder_t3.device return self From ab5122604e1998b0c2d2140b3624ad9936564867 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 Aug 2024 22:55:12 +0300 Subject: [PATCH 51/87] refactor: Update dtype in embedding_manager.py to match proj.weight --- examples/research_projects/anytext/embedding_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py index 256b0a60857c..5afda3eed5f1 100644 --- a/examples/research_projects/anytext/embedding_manager.py +++ b/examples/research_projects/anytext/embedding_manager.py @@ -65,7 +65,7 @@ def encode_text(self, text_info): if len(gline_list) > 0: recog_emb = self.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.device)) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) self.text_embs_all = [] n_idx = 0 From 1521e8f538d5f9bebee51c5b119d22cfed57df6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 Aug 2024 20:42:07 +0300 Subject: [PATCH 52/87] Up --- examples/research_projects/anytext/auxiliary_latent_module.py | 2 +- examples/research_projects/anytext/text_embedding_module.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py index fa83db89b206..4394be14ecd3 100644 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ b/examples/research_projects/anytext/auxiliary_latent_module.py @@ -81,7 +81,7 @@ def __init__( nn.SiLU(), ) - self.vae = vae.eval() + self.vae = vae.eval() if vae is not None else None self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py index f992cdd6024d..7b6c7c494955 100644 --- a/examples/research_projects/anytext/text_embedding_module.py +++ b/examples/research_projects/anytext/text_embedding_module.py @@ -24,12 +24,8 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): self.font = ImageFont.truetype(font_path, 60) self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) - # for param in self.embedding_manager.embedding_parameters(): - # param.requires_grad = True rec_model_dir = "OCR/ppv3_rec.pth" self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() - for param in self.text_predictor.parameters(): - param.requires_grad = False args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 From 56992d1167d3676c4b759ca6c21ebf720776ca7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 7 Oct 2024 11:37:26 +0300 Subject: [PATCH 53/87] Add attribution and adaptation information to pipeline_anytext.py --- examples/research_projects/anytext/pipeline_anytext.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 47f03f7e41f9..7686dec03146 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -11,6 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). +# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie +# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license +# +# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). import inspect From 7ad686584b3e6d1b368f06414c804c159feafef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 19:11:23 +0300 Subject: [PATCH 54/87] Update usage example --- .../anytext/pipeline_anytext.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 7686dec03146..1ecc4961f2be 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -75,11 +75,10 @@ >>> import torch >>> # load control net and stable diffusion v1-5 - >>> text_controlnet = TextControlNetModel.from_pretrained("a/TextControlNet", torch_dtype=torch.float16) - >>> pipe = AnyTextPipeline.from_pretrained( - ... "a/AnyText", controlnet=text_controlnet, torch_dtype=torch.float16, - ... variant="fp16" - ... ).to("cuda") + >>> text_controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16) + >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, + ... torch_dtype=torch.float16, variant="fp16", + ... ).to("cuda") >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization @@ -93,10 +92,9 @@ >>> generator = torch.Generator("cpu").manual_seed(66273235) >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") - >>> image = pipe( - ... prompt, num_inference_steps=20, generator=generator, mode="generate", - ... draw_pos=draw_pos - ... ).images[0] + >>> image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", + ... draw_pos=draw_pos, + ... ).images[0] >>> image ``` """ From a5edca5db77635166858dd863ed256a45b8ca2f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 19:12:02 +0300 Subject: [PATCH 55/87] Will refactor `controlnet_cond_embedding` initialization --- examples/research_projects/anytext/text_controlnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index cc6f2c59835d..c1d81b56f933 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -172,7 +172,9 @@ def __init__( global_pool_conditions, addition_embed_type_num_heads, ) - self.controlnet_cond_embedding = None # This part is computed inside AuxiliaryLatentModel + self.controlnet_cond_embedding = ( + None # TODO: Instead of this, design a custom `ControlNetConditioningEmbedding` + ) def forward( self, From 2f42e40254d7c694e367aa46c0d91e1a070235da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 15:14:01 +0300 Subject: [PATCH 56/87] Add `AnyTextControlNetConditioningEmbedding` template --- .../anytext/text_controlnet.py | 69 +++++++++++++++++-- 1 file changed, 64 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index c1d81b56f933..d860bde57fa4 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -14,6 +14,8 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn.functional as F +from torch import nn from diffusers.configuration_utils import register_to_config from diffusers.models.controlnet import ( @@ -26,6 +28,51 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class AnyTextControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + class AnyTextControlNetModel(ControlNetModel): """ A AnyTextControlNetModel model. @@ -172,8 +219,13 @@ def __init__( global_pool_conditions, addition_embed_type_num_heads, ) - self.controlnet_cond_embedding = ( - None # TODO: Instead of this, design a custom `ControlNetConditioningEmbedding` + + # control net conditioning embedding + # TODO: what happens ControlNetModel's self.controlnet_cond_embedding's memory occupation? + self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, ) def forward( @@ -181,7 +233,7 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - guided_hint: torch.Tensor, + controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, @@ -310,8 +362,8 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + guided_hint + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) @@ -375,3 +427,10 @@ def forward( return ControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) + + +# Copied from diffusers.models.controlnet.zero_module +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module From 670fef5df45315099885dd4d0cfd170cd5f87ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 18 Oct 2024 11:14:44 +0300 Subject: [PATCH 57/87] Refactor organization --- examples/research_projects/anytext/README.md | 15 +- .../anytext/auxiliary_latent_module.py | 184 ------- .../anytext/pipeline_anytext.py | 494 +++++++++++++++++- .../anytext/text_embedding_module.py | 301 ----------- 4 files changed, 502 insertions(+), 492 deletions(-) delete mode 100644 examples/research_projects/anytext/auxiliary_latent_module.py delete mode 100644 examples/research_projects/anytext/text_embedding_module.py diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index e81142990290..6f7be0e851af 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -22,11 +22,13 @@ import torch from pipeline_anytext import AnyTextPipeline from text_controlnet import TextControlNetModel from diffusers import DDIMScheduler +from diffusers.utils import load_image -controlnet = TextControlNetModel.from_pretrained("a/b", subfolder="controlnet", torch_dtype=torch.float16) -model_id = "path-to-model" -pipe = AnyTextPipeline.from_pretrained("a/b", subfolder="base", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16") +controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + variant="fp16") +pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=controlnet, + torch_dtype=torch.float16, variant="fp16") # speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -34,7 +36,10 @@ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) #pipe.enable_xformers_memory_efficient_attention() pipe.enable_model_cpu_offload() # generate image -generator = torch.Generator("cpu").manual_seed(0) -image = pipe("photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream", num_inference_steps=20, generator=generator).images[0] +generator = torch.Generator("cpu").manual_seed(66273235) +prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' +draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") +image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", draw_pos=draw_pos, + ).images[0] image ``` diff --git a/examples/research_projects/anytext/auxiliary_latent_module.py b/examples/research_projects/anytext/auxiliary_latent_module.py deleted file mode 100644 index 4394be14ecd3..000000000000 --- a/examples/research_projects/anytext/auxiliary_latent_module.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Optional - -import cv2 -import numpy as np -import torch -from PIL import ImageFont -from safetensors.torch import load_file -from torch import nn - -from diffusers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -class AuxiliaryLatentModule(nn.Module): - def __init__( - self, - font_path, - glyph_channels=1, - position_channels=1, - model_channels=320, - vae=None, - device="cpu", - use_fp16=False, - ): - super().__init__() - self.font = ImageFont.truetype(font_path, 60) - self.use_fp16 = use_fp16 - self.device = device - - self.glyph_block = nn.Sequential( - nn.Conv2d(glyph_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - ) - - self.position_block = nn.Sequential( - nn.Conv2d(position_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 64, 3, padding=1, stride=2), - nn.SiLU(), - ) - - self.vae = vae.eval() if vae is not None else None - - self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) - - self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) - self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) - self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) - - if use_fp16: - self.glyph_block = self.glyph_block.to(dtype=torch.float16) - self.position_block = self.position_block.to(dtype=torch.float16) - self.fuse_block = self.fuse_block.to(dtype=torch.float16) - - @torch.no_grad() - def forward( - self, - text_info, - mode, - draw_pos, - ori_image, - num_images_per_prompt, - np_hint, - h=512, - w=512, - ): - if mode == "generate": - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode == "edit": - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from {ori_image}!") - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - - # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) - masked_img = np.transpose(masked_img, (2, 0, 1)) - masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) - if self.use_fp16: - masked_img = masked_img.half() - masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() - if self.use_fp16: - masked_x = masked_x.half() - text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) - - glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - enc_glyph = self.glyph_block(glyphs) - enc_pos = self.position_block(positions) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) - - return guided_hint - - def check_channels(self, image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - def resize_image(self, img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - def to(self, device): - self.device = device - self.glyph_block = self.glyph_block.to(device) - self.position_block = self.position_block.to(device) - self.fuse_block = self.fuse_block.to(device) - self.vae = self.vae.to(device) - return self diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 1ecc4961f2be..578a7a42c653 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -27,7 +27,6 @@ import PIL.Image import torch import torch.nn.functional as F -from auxiliary_latent_module import AuxiliaryLatentModule from bert_tokenizer import BasicTokenizer from text_embedding_module import TextEmbeddingModule from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -75,7 +74,8 @@ >>> import torch >>> # load control net and stable diffusion v1-5 - >>> text_controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16) + >>> text_controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + ... variant="fp16",) >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, ... torch_dtype=torch.float16, variant="fp16", ... ).to("cuda") @@ -100,6 +100,496 @@ """ +import cv2 +import numpy as np +import torch +from embedding_manager import EmbeddingManager +from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from PIL import Image, ImageDraw, ImageFont +from recognizer import TextRecognizer, create_predictor +from torch import nn + +from diffusers.utils import ( + logging, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TextEmbeddingModule(nn.Module): + def __init__(self, font_path, use_fp16=False, device="cpu"): + super().__init__() + self.use_fp16 = use_fp16 + self.device = device + # TODO: Learn if the recommended font file is free to use + self.font = ImageFont.truetype(font_path, 60) + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) + rec_model_dir = "OCR/ppv3_rec.pth" + self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() + args = {} + args["rec_image_shape"] = "3, 48, 320" + args["rec_batch_num"] = 6 + args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" + args["use_fp16"] = self.use_fp16 + self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) + + @torch.no_grad() + def forward( + self, + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + sort_priority="↕", + max_chars=77, + revise_pos=False, + h=512, + w=512, + ): + if prompt is None and texts is None: + raise ValueError("Prompt or texts must be provided!") + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from {draw_pos}!") + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + if mode == "edit": + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + n_lines = len(texts) + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + text_info = {} + text_info["glyphs"] = [] + text_info["gly_line"] = [] + text_info["positions"] = [] + text_info["n_lines"] = [len(texts)] * num_images_per_prompt + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = self.draw_glyph(self.font, text) + glyphs = self.draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + logger.warning(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + pos = pre_pos[i] + text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] + text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] + text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] + + # hint = self.arr2tensor(np_hint, len(prompt)) + + self.embedding_manager.encode_text(text_info) + prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) + + self.embedding_manager.encode_text(text_info) + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( + [negative_prompt], embedding_manager=self.embedding_manager + ) + + return prompt_embeds, negative_prompt_embeds, text_info, np_hint + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().cpu() + if self.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr + + def separate_pos_imgs(self, img, sort_priority, gap=102): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) + components = [] + for label in range(1, num_labels): + component = np.zeros_like(img) + component[labels == label] = 255 + components.append((component, centroids[label])) + if sort_priority == "↕": + fir, sec = 1, 0 # top-down first + elif sort_priority == "↔": + fir, sec = 0, 1 # left-right first + else: + raise ValueError(f"Unknown sort_priority: {sort_priority}") + components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) + sorted_components = [c[0] for c in components] + return sorted_components + + def find_polygon(self, image, min_rect=False): + contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + max_contour = max(contours, key=cv2.contourArea) # get contour with max area + if min_rect: + # get minimum enclosing rectangle + rect = cv2.minAreaRect(max_contour) + poly = np.int0(cv2.boxPoints(rect)) + else: + # get approximate polygon + epsilon = 0.01 * cv2.arcLength(max_contour, True) + poly = cv2.approxPolyDP(max_contour, epsilon, True) + n, _, xy = poly.shape + poly = poly.reshape(n, xy) + cv2.drawContours(image, [poly], -1, 255, -1) + return poly, image + + def draw_glyph(self, font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def to(self, *args, **kwargs): + self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) + self.embedding_manager = self.embedding_manager.to(*args, **kwargs) + self.text_predictor = self.text_predictor.to(*args, **kwargs) + self.device = self.frozen_CLIP_embedder_t3.device + return self + + +from typing import Optional + +import cv2 +import numpy as np +import torch +from PIL import ImageFont +from safetensors.torch import load_file +from torch import nn + +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AuxiliaryLatentModule(nn.Module): + def __init__( + self, + font_path, + glyph_channels=1, + position_channels=1, + model_channels=320, + vae=None, + device="cpu", + use_fp16=False, + ): + super().__init__() + self.font = ImageFont.truetype(font_path, 60) + self.use_fp16 = use_fp16 + self.device = device + + self.glyph_block = nn.Sequential( + nn.Conv2d(glyph_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.position_block = nn.Sequential( + nn.Conv2d(position_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 64, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.vae = vae.eval() if vae is not None else None + + self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) + + self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) + self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) + self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) + + if use_fp16: + self.glyph_block = self.glyph_block.to(dtype=torch.float16) + self.position_block = self.position_block.to(dtype=torch.float16) + self.fuse_block = self.fuse_block.to(dtype=torch.float16) + + @torch.no_grad() + def forward( + self, + text_info, + mode, + draw_pos, + ori_image, + num_images_per_prompt, + np_hint, + h=512, + w=512, + ): + if mode == "generate": + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode == "edit": + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) + if self.use_fp16: + masked_img = masked_img.half() + masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() + if self.use_fp16: + masked_x = masked_x.half() + text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) + + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) + enc_glyph = self.glyph_block(glyphs) + enc_pos = self.position_block(positions) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) + + return guided_hint + + def check_channels(self, image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + def resize_image(self, img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def to(self, device): + self.device = device + self.glyph_block = self.glyph_block.to(device) + self.position_block = self.position_block.to(device) + self.fuse_block = self.fuse_block.to(device) + self.vae = self.vae.to(device) + return self + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, diff --git a/examples/research_projects/anytext/text_embedding_module.py b/examples/research_projects/anytext/text_embedding_module.py deleted file mode 100644 index 7b6c7c494955..000000000000 --- a/examples/research_projects/anytext/text_embedding_module.py +++ /dev/null @@ -1,301 +0,0 @@ -import cv2 -import numpy as np -import torch -from embedding_manager import EmbeddingManager -from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 -from PIL import Image, ImageDraw, ImageFont -from recognizer import TextRecognizer, create_predictor -from torch import nn - -from diffusers.utils import ( - logging, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class TextEmbeddingModule(nn.Module): - def __init__(self, font_path, use_fp16=False, device="cpu"): - super().__init__() - self.use_fp16 = use_fp16 - self.device = device - # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype(font_path, 60) - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) - rec_model_dir = "OCR/ppv3_rec.pth" - self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() - args = {} - args["rec_image_shape"] = "3, 48, 320" - args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" - args["use_fp16"] = self.use_fp16 - self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) - - @torch.no_grad() - def forward( - self, - prompt, - texts, - negative_prompt, - num_images_per_prompt, - mode, - draw_pos, - sort_priority="↕", - max_chars=77, - revise_pos=False, - h=512, - w=512, - ): - if prompt is None and texts is None: - raise ValueError("Prompt or texts must be provided!") - # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) - if draw_pos is None: - pos_imgs = np.zeros((w, h, 1)) - if isinstance(draw_pos, str): - draw_pos = cv2.imread(draw_pos)[..., ::-1] - if draw_pos is None: - raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - pos_imgs = 255 - draw_pos - elif isinstance(draw_pos, torch.Tensor): - pos_imgs = draw_pos.cpu().numpy() - else: - if not isinstance(draw_pos, np.ndarray): - raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - if mode == "edit": - pos_imgs = cv2.resize(pos_imgs, (w, h)) - pos_imgs = pos_imgs[..., 0:1] - pos_imgs = cv2.convertScaleAbs(pos_imgs) - _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) - # separate pos_imgs - pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) - if len(pos_imgs) == 0: - pos_imgs = [np.zeros((h, w, 1))] - n_lines = len(texts) - if len(pos_imgs) < n_lines: - if n_lines == 1 and texts[0] == " ": - pass # text-to-image without text - else: - raise ValueError( - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" - ) - elif len(pos_imgs) > n_lines: - str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." - logger.warning(str_warning) - # get pre_pos, poly_list, hint that needed for anytext - pre_pos = [] - poly_list = [] - for input_pos in pos_imgs: - if input_pos.mean() != 0: - input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos - poly, pos_img = self.find_polygon(input_pos) - pre_pos += [pos_img / 255.0] - poly_list += [poly] - else: - pre_pos += [np.zeros((h, w, 1))] - poly_list += [None] - np_hint = np.sum(pre_pos, axis=0).clip(0, 1) - # prepare info dict - text_info = {} - text_info["glyphs"] = [] - text_info["gly_line"] = [] - text_info["positions"] = [] - text_info["n_lines"] = [len(texts)] * num_images_per_prompt - for i in range(len(texts)): - text = texts[i] - if len(text) > max_chars: - str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' - logger.warning(str_warning) - text = text[:max_chars] - gly_scale = 2 - if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.font, text) - glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - if revise_pos: - resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) - new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos - contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - logger.warning(str_warning) - else: - rect = cv2.minAreaRect(contours[0]) - poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 - else: - glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_line = np.zeros((80, 512, 1)) - pos = pre_pos[i] - text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] - text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] - text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] - - # hint = self.arr2tensor(np_hint, len(prompt)) - - self.embedding_manager.encode_text(text_info) - prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) - - self.embedding_manager.encode_text(text_info) - negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( - [negative_prompt], embedding_manager=self.embedding_manager - ) - - return prompt_embeds, negative_prompt_embeds, text_info, np_hint - - def arr2tensor(self, arr, bs): - arr = np.transpose(arr, (2, 0, 1)) - _arr = torch.from_numpy(arr.copy()).float().cpu() - if self.use_fp16: - _arr = _arr.half() - _arr = torch.stack([_arr for _ in range(bs)], dim=0) - return _arr - - def separate_pos_imgs(self, img, sort_priority, gap=102): - num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) - components = [] - for label in range(1, num_labels): - component = np.zeros_like(img) - component[labels == label] = 255 - components.append((component, centroids[label])) - if sort_priority == "↕": - fir, sec = 1, 0 # top-down first - elif sort_priority == "↔": - fir, sec = 0, 1 # left-right first - else: - raise ValueError(f"Unknown sort_priority: {sort_priority}") - components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) - sorted_components = [c[0] for c in components] - return sorted_components - - def find_polygon(self, image, min_rect=False): - contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - max_contour = max(contours, key=cv2.contourArea) # get contour with max area - if min_rect: - # get minimum enclosing rectangle - rect = cv2.minAreaRect(max_contour) - poly = np.int0(cv2.boxPoints(rect)) - else: - # get approximate polygon - epsilon = 0.01 * cv2.arcLength(max_contour, True) - poly = cv2.approxPolyDP(max_contour, epsilon, True) - n, _, xy = poly.shape - poly = poly.reshape(n, xy) - cv2.drawContours(image, [poly], -1, 255, -1) - return poly, image - - def draw_glyph(self, font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = self.insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = self.insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - def to(self, *args, **kwargs): - self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) - self.embedding_manager = self.embedding_manager.to(*args, **kwargs) - self.text_predictor = self.text_predictor.to(*args, **kwargs) - self.device = self.frozen_CLIP_embedder_t3.device - return self From 930c37ad7d4760b70c16f15eab34b64bc45f81fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 18 Oct 2024 11:16:24 +0300 Subject: [PATCH 58/87] style --- .../anytext/pipeline_anytext.py | 41 ++++--------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 578a7a42c653..c0b7d8ca2835 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -23,12 +23,18 @@ import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import cv2 import numpy as np import PIL.Image import torch import torch.nn.functional as F from bert_tokenizer import BasicTokenizer -from text_embedding_module import TextEmbeddingModule +from embedding_manager import EmbeddingManager +from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from PIL import Image, ImageDraw, ImageFont +from recognizer import TextRecognizer, create_predictor +from safetensors.torch import load_file +from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -100,23 +106,6 @@ """ -import cv2 -import numpy as np -import torch -from embedding_manager import EmbeddingManager -from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 -from PIL import Image, ImageDraw, ImageFont -from recognizer import TextRecognizer, create_predictor -from torch import nn - -from diffusers.utils import ( - logging, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - class TextEmbeddingModule(nn.Module): def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() @@ -403,21 +392,6 @@ def to(self, *args, **kwargs): return self -from typing import Optional - -import cv2 -import numpy as np -import torch -from PIL import ImageFont -from safetensors.torch import load_file -from torch import nn - -from diffusers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -589,7 +563,6 @@ def to(self, device): return self - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, From 21c0c35d5bc87c03623f373fb531c1d2cc407f57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 20 Oct 2024 17:12:30 +0300 Subject: [PATCH 59/87] style --- src/diffusers/dependency_versions_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 9e7bf242eca7..0e421b71e48d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -38,7 +38,7 @@ "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", - "torch": "torch>=1.4", + "torch": "torch>=1.4,<2.5.0", "torchvision": "torchvision", "transformers": "transformers>=4.41.2", "urllib3": "urllib3<=2.0.0", From c4db96a37a92a2d9e9cef55b1f108c9876373feb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 20 Oct 2024 17:31:35 +0300 Subject: [PATCH 60/87] Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding` --- .../anytext/pipeline_anytext.py | 58 +------------- .../anytext/text_controlnet.py | 75 +++++++++++++------ 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index c0b7d8ca2835..1c6a4fb6609e 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -33,7 +33,6 @@ from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from PIL import Image, ImageDraw, ImageFont from recognizer import TextRecognizer, create_predictor -from safetensors.torch import load_file from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -410,9 +409,6 @@ class AuxiliaryLatentModule(nn.Module): def __init__( self, font_path, - glyph_channels=1, - position_channels=1, - model_channels=320, vae=None, device="cpu", use_fp16=False, @@ -422,57 +418,8 @@ def __init__( self.use_fp16 = use_fp16 self.device = device - self.glyph_block = nn.Sequential( - nn.Conv2d(glyph_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - ) - - self.position_block = nn.Sequential( - nn.Conv2d(position_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 64, 3, padding=1, stride=2), - nn.SiLU(), - ) - self.vae = vae.eval() if vae is not None else None - self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) - - self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) - self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) - self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) - - if use_fp16: - self.glyph_block = self.glyph_block.to(dtype=torch.float16) - self.position_block = self.position_block.to(dtype=torch.float16) - self.fuse_block = self.fuse_block.to(dtype=torch.float16) - @torch.no_grad() def forward( self, @@ -518,11 +465,8 @@ def forward( glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - enc_glyph = self.glyph_block(glyphs) - enc_pos = self.position_block(positions) - guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)) - return guided_hint + return glyphs, positions, text_info def check_channels(self, image): channels = image.shape[2] if len(image.shape) == 3 else 1 diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/text_controlnet.py index d860bde57fa4..8ad40af78063 100644 --- a/examples/research_projects/anytext/text_controlnet.py +++ b/examples/research_projects/anytext/text_controlnet.py @@ -14,7 +14,6 @@ from typing import Any, Dict, Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import register_to_config @@ -40,37 +39,67 @@ class AnyTextControlNetConditioningEmbedding(nn.Module): def __init__( self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + glyph_channels=1, + position_channels=1, + model_channels=320, ): super().__init__() - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + self.glyph_block = nn.Sequential( + nn.Conv2d(glyph_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + ) - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + self.position_block = nn.Sequential( + nn.Conv2d(position_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 64, 3, padding=1, stride=2), + nn.SiLU(), ) - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) + self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) + + # self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) + # self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) + # self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) + # if use_fp16: + # self.glyph_block = self.glyph_block.to(dtype=torch.float16) + # self.position_block = self.position_block.to(dtype=torch.float16) + # self.fuse_block = self.fuse_block.to(dtype=torch.float16) - embedding = self.conv_out(embedding) + def forward(self, glyphs, positions, text_info): + glyph_embedding = self.glyph_block(glyphs) + position_embedding = self.position_block(positions) + guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1)) - return embedding + return guided_hint class AnyTextControlNetModel(ControlNetModel): From 6bd0b4cd96ca1a0b54645c20f6eec39ef265f1a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 3 Nov 2024 10:20:17 +0300 Subject: [PATCH 61/87] Follow one-file policy --- .../anytext/embedding_manager.py | 101 ----- .../anytext/pipeline_anytext.py | 366 +++++++++++++++++- .../research_projects/anytext/recognizer.py | 306 --------------- 3 files changed, 364 insertions(+), 409 deletions(-) delete mode 100644 examples/research_projects/anytext/embedding_manager.py delete mode 100755 examples/research_projects/anytext/recognizer.py diff --git a/examples/research_projects/anytext/embedding_manager.py b/examples/research_projects/anytext/embedding_manager.py deleted file mode 100644 index 5afda3eed5f1..000000000000 --- a/examples/research_projects/anytext/embedding_manager.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Copyright (c) Alibaba, Inc. and its affiliates. -""" -from functools import partial - -import torch -import torch.nn as nn -from safetensors.torch import load_file - - -def get_clip_token_for_string(tokenizer, string): - batch_encoding = tokenizer( - string, - truncation=True, - max_length=77, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"] - assert ( - torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" - return tokens[0, 1] - - -def get_recog_emb(encoder, img_list): - _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] - encoder.predictor.eval() - _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) - return preds_neck - - -class EmbeddingManager(nn.Module): - def __init__( - self, - embedder, - placeholder_string="*", - use_fp16=False, - ): - super().__init__() - get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - token_dim = 768 - self.get_recog_emb = None - self.token_dim = token_dim - - self.proj = nn.Linear(40 * 64, token_dim) - self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) - if use_fp16: - self.proj = self.proj.to(dtype=torch.float16) - - self.placeholder_token = get_token_for_string(placeholder_string) - - @torch.no_grad() - def encode_text(self, text_info): - if self.get_recog_emb is None: - self.get_recog_emb = partial(get_recog_emb, self.recog) - - gline_list = [] - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - for j in range(n_lines): # line - gline_list += [text_info["gly_line"][j][i : i + 1]] - - if len(gline_list) > 0: - recog_emb = self.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) - - self.text_embs_all = [] - n_idx = 0 - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - text_embs = [] - for j in range(n_lines): # line - text_embs += [enc_glyph[n_idx : n_idx + 1]] - n_idx += 1 - self.text_embs_all += [text_embs] - - @torch.no_grad() - def forward( - self, - tokenized_text, - embedded_text, - ): - b, device = tokenized_text.shape[0], tokenized_text.device - for i in range(b): - idx = tokenized_text[i] == self.placeholder_token.to(device) - if sum(idx) > 0: - if i >= len(self.text_embs_all): - print("truncation for log images...") - break - text_emb = torch.cat(self.text_embs_all[i], dim=0) - if sum(idx) != len(text_emb): - print("truncation for long caption...") - text_emb = text_emb.to(embedded_text.device) - embedded_text[i][idx] = text_emb[: sum(idx)] - return embedded_text - - def embedding_parameters(self): - return self.parameters() diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 1c6a4fb6609e..219a93a330a9 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -29,10 +29,8 @@ import torch import torch.nn.functional as F from bert_tokenizer import BasicTokenizer -from embedding_manager import EmbeddingManager from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from PIL import Image, ImageDraw, ImageFont -from recognizer import TextRecognizer, create_predictor from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -104,6 +102,370 @@ ``` """ +""" +Copyright (c) Alibaba, Inc. and its affiliates. +""" +from functools import partial + +import torch +import torch.nn as nn +from safetensors.torch import load_file + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"] + assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" + return tokens[0, 1] + + +def get_recog_emb(encoder, img_list): + _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] + encoder.predictor.eval() + _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) + return preds_neck + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_string="*", + use_fp16=False, + ): + super().__init__() + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + token_dim = 768 + self.get_recog_emb = None + self.token_dim = token_dim + + self.proj = nn.Linear(40 * 64, token_dim) + self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) + if use_fp16: + self.proj = self.proj.to(dtype=torch.float16) + + self.placeholder_token = get_token_for_string(placeholder_string) + + @torch.no_grad() + def encode_text(self, text_info): + if self.get_recog_emb is None: + self.get_recog_emb = partial(get_recog_emb, self.recog) + + gline_list = [] + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + for j in range(n_lines): # line + gline_list += [text_info["gly_line"][j][i : i + 1]] + + if len(gline_list) > 0: + recog_emb = self.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) + + self.text_embs_all = [] + n_idx = 0 + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + text_embs = [] + for j in range(n_lines): # line + text_embs += [enc_glyph[n_idx : n_idx + 1]] + n_idx += 1 + self.text_embs_all += [text_embs] + + @torch.no_grad() + def forward( + self, + tokenized_text, + embedded_text, + ): + b, device = tokenized_text.shape[0], tokenized_text.device + for i in range(b): + idx = tokenized_text[i] == self.placeholder_token.to(device) + if sum(idx) > 0: + if i >= len(self.text_embs_all): + print("truncation for log images...") + break + text_emb = torch.cat(self.text_embs_all[i], dim=0) + if sum(idx) != len(text_emb): + print("truncation for long caption...") + text_emb = text_emb.to(embedded_text.device) + embedded_text[i][idx] = text_emb[: sum(idx)] + return embedded_text + + def embedding_parameters(self): + return self.parameters() + + + +""" +Copyright (c) Alibaba, Inc. and its affiliates. +""" +import math +import os +import sys +import time +import traceback + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict +from ocr_recog.RecModel import RecModel +from skimage.transform._geometric import _umeyama as get_sym_mat + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def min_bounding_rect(img): + ret, thresh = cv2.threshold(img, 127, 255, 0) + contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + print("Bad contours, using fake bbox...") + return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) + max_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(max_contour) + box = cv2.boxPoints(rect) + box = np.int0(box) + # sort + x_sorted = sorted(box, key=lambda x: x[0]) + left = x_sorted[:2] + right = x_sorted[2:] + left = sorted(left, key=lambda x: x[1]) + (tl, bl) = left + right = sorted(right, key=lambda x: x[1]) + (tr, br) = right + if tl[1] > bl[1]: + (tl, bl) = (bl, tl) + if tr[1] > br[1]: + (tr, br) = (br, tr) + return np.array([tl, tr, br, bl]) + + +def adjust_image(box, img): + pts1 = np.float32([box[0], box[1], box[2], box[3]]) + width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) + height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) + pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) + # get transform matrix + M = get_sym_mat(pts1, pts2, estimate_scale=True) + C, H, W = img.shape + T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) + theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) + theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) + grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) + result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) + result = torch.clamp(result.squeeze(0), 0, 255) + # crop + result = result[:, : int(height), : int(width)] + return result + + +""" +mask: numpy.ndarray, mask of textual, HWC +src_img: torch.Tensor, source image, CHW +""" + + +def crop_image(src_img, mask): + box = min_bounding_rect(mask) + result = adjust_image(box, src_img) + if len(result.shape) == 2: + result = torch.stack([result] * 3, axis=-1) + return result + + +def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): + model_file_path = model_dir + if model_file_path is not None and not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format(model_file_path)) + + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 + else: + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = edict( + in_channels=3, + backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), + neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), + head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + ) + + rec_model = RecModel(rec_config) + if model_file_path is not None: + rec_model.load_state_dict(torch.load(model_file_path, map_location=device)) + return rec_model + + +def _check_image_file(path): + img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") + return path.lower().endswith(tuple(img_end)) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +class TextRecognizer(object): + def __init__(self, args, predictor): + self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] + self.rec_batch_num = args["rec_batch_num"] + self.predictor = predictor + self.chars = self.get_char_dict(args["rec_char_dict_path"]) + self.char2id = {x: i for i, x in enumerate(self.chars)} + self.is_onnx = not isinstance(self.predictor, torch.nn.Module) + self.use_fp16 = args["use_fp16"] + + # img: CHW + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[0] + imgW = int((imgH * max_wh_ratio)) + + h, w = img.shape[1:] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = torch.nn.functional.interpolate( + img.unsqueeze(0), + size=(imgH, resized_w), + mode="bilinear", + align_corners=True, + ) + resized_image /= 255.0 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) + padding_im[:, :, 0:resized_w] = resized_image[0] + return padding_im + + # img_list: list of tensors with shape chw 0-255 + def pred_imglist(self, img_list, show_debug=False): + img_num = len(img_list) + assert img_num > 0 + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[2] / float(img.shape[1])) + # Sorting can speed up the recognition process + indices = torch.from_numpy(np.argsort(np.array(width_list))) + batch_num = self.rec_batch_num + preds_all = [None] * img_num + preds_neck_all = [None] * img_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[1:] + if h > w * 1.2: + img = img_list[indices[ino]] + img = torch.transpose(img, 1, 2).flip(dims=[1]) + img_list[indices[ino]] = img + h, w = img.shape[1:] + # wh_ratio = w * 1.0 / h + # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + if self.use_fp16: + norm_img = norm_img.half() + norm_img = norm_img.unsqueeze(0) + norm_img_batch.append(norm_img) + norm_img_batch = torch.cat(norm_img_batch, dim=0) + if show_debug: + for i in range(len(norm_img_batch)): + _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() + _img = (_img + 0.5) * 255 + _img = _img[:, :, ::-1] + file_name = f"{indices[beg_img_no + i]}" + if os.path.exists(file_name + ".jpg"): + file_name += "_2" # ori image + cv2.imwrite(file_name + ".jpg", _img) + if self.is_onnx: + input_dict = {} + input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() + outputs = self.predictor.run(None, input_dict) + preds = {} + preds["ctc"] = torch.from_numpy(outputs[0]) + preds["ctc_neck"] = [torch.zeros(1)] * img_num + else: + preds = self.predictor(norm_img_batch) + for rno in range(preds["ctc"].shape[0]): + preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] + preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] + + return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) + + def get_char_dict(self, character_dict_path): + character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str.append(line) + dict_character = list(character_str) + dict_character = ["sos"] + dict_character + [" "] # eos is space + return dict_character + + def get_text(self, order): + char_list = [self.chars[text_id] for text_id in order] + return "".join(char_list) + + def decode(self, mat): + text_index = mat.detach().cpu().numpy().argmax(axis=1) + ignored_tokens = [0] + selection = np.ones(len(text_index), dtype=bool) + selection[1:] = text_index[1:] != text_index[:-1] + for ignored_token in ignored_tokens: + selection &= text_index != ignored_token + return text_index[selection], np.where(selection)[0] + + def get_ctcloss(self, preds, gt_text, weight): + if not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight).to(preds.device) + ctc_loss = torch.nn.CTCLoss(reduction="none") + log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC + targets = [] + target_lengths = [] + for t in gt_text: + targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] + target_lengths += [len(t)] + targets = torch.tensor(targets).to(preds.device) + target_lengths = torch.tensor(target_lengths).to(preds.device) + input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) + loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) + loss = loss / input_lengths * weight + return loss + class TextEmbeddingModule(nn.Module): def __init__(self, font_path, use_fp16=False, device="cpu"): diff --git a/examples/research_projects/anytext/recognizer.py b/examples/research_projects/anytext/recognizer.py deleted file mode 100755 index 5cd7e245509d..000000000000 --- a/examples/research_projects/anytext/recognizer.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Copyright (c) Alibaba, Inc. and its affiliates. -""" -import math -import os -import sys -import time -import traceback - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -from easydict import EasyDict as edict -from ocr_recog.RecModel import RecModel -from skimage.transform._geometric import _umeyama as get_sym_mat - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - - -def min_bounding_rect(img): - ret, thresh = cv2.threshold(img, 127, 255, 0) - contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if len(contours) == 0: - print("Bad contours, using fake bbox...") - return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) - max_contour = max(contours, key=cv2.contourArea) - rect = cv2.minAreaRect(max_contour) - box = cv2.boxPoints(rect) - box = np.int0(box) - # sort - x_sorted = sorted(box, key=lambda x: x[0]) - left = x_sorted[:2] - right = x_sorted[2:] - left = sorted(left, key=lambda x: x[1]) - (tl, bl) = left - right = sorted(right, key=lambda x: x[1]) - (tr, br) = right - if tl[1] > bl[1]: - (tl, bl) = (bl, tl) - if tr[1] > br[1]: - (tr, br) = (br, tr) - return np.array([tl, tr, br, bl]) - - -def adjust_image(box, img): - pts1 = np.float32([box[0], box[1], box[2], box[3]]) - width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) - height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) - pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) - # get transform matrix - M = get_sym_mat(pts1, pts2, estimate_scale=True) - C, H, W = img.shape - T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) - theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) - theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) - grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) - result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) - result = torch.clamp(result.squeeze(0), 0, 255) - # crop - result = result[:, : int(height), : int(width)] - return result - - -""" -mask: numpy.ndarray, mask of textual, HWC -src_img: torch.Tensor, source image, CHW -""" - - -def crop_image(src_img, mask): - box = min_bounding_rect(mask) - result = adjust_image(box, src_img) - if len(result.shape) == 2: - result = torch.stack([result] * 3, axis=-1) - return result - - -def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): - model_file_path = model_dir - if model_file_path is not None and not os.path.exists(model_file_path): - raise ValueError("not find model file path {}".format(model_file_path)) - - if model_lang == "ch": - n_class = 6625 - elif model_lang == "en": - n_class = 97 - else: - raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = edict( - in_channels=3, - backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), - neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), - head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), - ) - - rec_model = RecModel(rec_config) - if model_file_path is not None: - rec_model.load_state_dict(torch.load(model_file_path, map_location=device)) - return rec_model - - -def _check_image_file(path): - img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") - return path.lower().endswith(tuple(img_end)) - - -def get_image_file_list(img_file): - imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - if os.path.isfile(img_file) and _check_image_file(img_file): - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for single_file in os.listdir(img_file): - file_path = os.path.join(img_file, single_file) - if os.path.isfile(file_path) and _check_image_file(file_path): - imgs_lists.append(file_path) - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - imgs_lists = sorted(imgs_lists) - return imgs_lists - - -class TextRecognizer(object): - def __init__(self, args, predictor): - self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] - self.rec_batch_num = args["rec_batch_num"] - self.predictor = predictor - self.chars = self.get_char_dict(args["rec_char_dict_path"]) - self.char2id = {x: i for i, x in enumerate(self.chars)} - self.is_onnx = not isinstance(self.predictor, torch.nn.Module) - self.use_fp16 = args["use_fp16"] - - # img: CHW - def resize_norm_img(self, img, max_wh_ratio): - imgC, imgH, imgW = self.rec_image_shape - assert imgC == img.shape[0] - imgW = int((imgH * max_wh_ratio)) - - h, w = img.shape[1:] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = torch.nn.functional.interpolate( - img.unsqueeze(0), - size=(imgH, resized_w), - mode="bilinear", - align_corners=True, - ) - resized_image /= 255.0 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) - padding_im[:, :, 0:resized_w] = resized_image[0] - return padding_im - - # img_list: list of tensors with shape chw 0-255 - def pred_imglist(self, img_list, show_debug=False): - img_num = len(img_list) - assert img_num > 0 - # Calculate the aspect ratio of all text bars - width_list = [] - for img in img_list: - width_list.append(img.shape[2] / float(img.shape[1])) - # Sorting can speed up the recognition process - indices = torch.from_numpy(np.argsort(np.array(width_list))) - batch_num = self.rec_batch_num - preds_all = [None] * img_num - preds_neck_all = [None] * img_num - for beg_img_no in range(0, img_num, batch_num): - end_img_no = min(img_num, beg_img_no + batch_num) - norm_img_batch = [] - - imgC, imgH, imgW = self.rec_image_shape[:3] - max_wh_ratio = imgW / imgH - for ino in range(beg_img_no, end_img_no): - h, w = img_list[indices[ino]].shape[1:] - if h > w * 1.2: - img = img_list[indices[ino]] - img = torch.transpose(img, 1, 2).flip(dims=[1]) - img_list[indices[ino]] = img - h, w = img.shape[1:] - # wh_ratio = w * 1.0 / h - # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio - for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) - if self.use_fp16: - norm_img = norm_img.half() - norm_img = norm_img.unsqueeze(0) - norm_img_batch.append(norm_img) - norm_img_batch = torch.cat(norm_img_batch, dim=0) - if show_debug: - for i in range(len(norm_img_batch)): - _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() - _img = (_img + 0.5) * 255 - _img = _img[:, :, ::-1] - file_name = f"{indices[beg_img_no + i]}" - if os.path.exists(file_name + ".jpg"): - file_name += "_2" # ori image - cv2.imwrite(file_name + ".jpg", _img) - if self.is_onnx: - input_dict = {} - input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() - outputs = self.predictor.run(None, input_dict) - preds = {} - preds["ctc"] = torch.from_numpy(outputs[0]) - preds["ctc_neck"] = [torch.zeros(1)] * img_num - else: - preds = self.predictor(norm_img_batch) - for rno in range(preds["ctc"].shape[0]): - preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] - preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] - - return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) - - def get_char_dict(self, character_dict_path): - character_str = [] - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - line = line.decode("utf-8").strip("\n").strip("\r\n") - character_str.append(line) - dict_character = list(character_str) - dict_character = ["sos"] + dict_character + [" "] # eos is space - return dict_character - - def get_text(self, order): - char_list = [self.chars[text_id] for text_id in order] - return "".join(char_list) - - def decode(self, mat): - text_index = mat.detach().cpu().numpy().argmax(axis=1) - ignored_tokens = [0] - selection = np.ones(len(text_index), dtype=bool) - selection[1:] = text_index[1:] != text_index[:-1] - for ignored_token in ignored_tokens: - selection &= text_index != ignored_token - return text_index[selection], np.where(selection)[0] - - def get_ctcloss(self, preds, gt_text, weight): - if not isinstance(weight, torch.Tensor): - weight = torch.tensor(weight).to(preds.device) - ctc_loss = torch.nn.CTCLoss(reduction="none") - log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC - targets = [] - target_lengths = [] - for t in gt_text: - targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] - target_lengths += [len(t)] - targets = torch.tensor(targets).to(preds.device) - target_lengths = torch.tensor(target_lengths).to(preds.device) - input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) - loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) - loss = loss / input_lengths * weight - return loss - - -def main(): - rec_model_dir = "./ocr_weights/ppv3_rec.pth" - predictor = create_predictor(rec_model_dir) - args = edict() - args.rec_image_shape = "3, 48, 320" - args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt" - args.rec_batch_num = 6 - text_recognizer = TextRecognizer(args, predictor) - image_dir = "./test_imgs_cn" - gt_text = ["韩国小馆"] * 14 - - image_file_list = get_image_file_list(image_dir) - valid_image_file_list = [] - img_list = [] - - for image_file in image_file_list: - img = cv2.imread(image_file) - if img is None: - print("error in loading image:{}".format(image_file)) - continue - valid_image_file_list.append(image_file) - img_list.append(torch.from_numpy(img).permute(2, 0, 1).float()) - try: - tic = time.time() - times = [] - for i in range(10): - preds, _ = text_recognizer.pred_imglist(img_list) # get text - preds_all = preds.softmax(dim=2) - times += [(time.time() - tic) * 1000.0] - tic = time.time() - print(times) - print(np.mean(times[1:]) / len(preds_all)) - weight = np.ones(len(gt_text)) - loss = text_recognizer.get_ctcloss(preds, gt_text, weight) - for i in range(len(valid_image_file_list)): - pred = preds_all[i] - order, idx = text_recognizer.decode(pred) - text = text_recognizer.get_text(order) - print(f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}') - except Exception as E: - print(traceback.format_exc(), E) - - -if __name__ == "__main__": - main() From b3f98a77c42c17df05d2f95c0b01c53148cae785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 3 Nov 2024 10:22:14 +0300 Subject: [PATCH 62/87] style --- .../anytext/pipeline_anytext.py | 37 +++++-------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 219a93a330a9..3a8d6017ae8c 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -1,4 +1,5 @@ # Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +21,11 @@ import inspect +import math +import os import re +import sys +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cv2 @@ -29,8 +34,12 @@ import torch import torch.nn.functional as F from bert_tokenizer import BasicTokenizer +from easydict import EasyDict as edict from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from ocr_recog.RecModel import RecModel from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file +from skimage.transform._geometric import _umeyama as get_sym_mat from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -102,15 +111,6 @@ ``` """ -""" -Copyright (c) Alibaba, Inc. and its affiliates. -""" -from functools import partial - -import torch -import torch.nn as nn -from safetensors.torch import load_file - def get_clip_token_for_string(tokenizer, string): batch_encoding = tokenizer( @@ -205,25 +205,6 @@ def embedding_parameters(self): return self.parameters() - -""" -Copyright (c) Alibaba, Inc. and its affiliates. -""" -import math -import os -import sys -import time -import traceback - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -from easydict import EasyDict as edict -from ocr_recog.RecModel import RecModel -from skimage.transform._geometric import _umeyama as get_sym_mat - - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) From 75a0f1f68c3c1025523bfdeea9ef7083d283bea9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 13 Jan 2025 11:17:48 +0300 Subject: [PATCH 63/87] [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel --- examples/research_projects/anytext/README.md | 4 ++-- .../anytext/{text_controlnet.py => anytext_controlnet.py} | 0 examples/research_projects/anytext/pipeline_anytext.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename examples/research_projects/anytext/{text_controlnet.py => anytext_controlnet.py} (100%) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index 6f7be0e851af..661414e2d02a 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -20,12 +20,12 @@ To learn about how to convert the fine-tuned stable diffusion model, see the [Lo ```py import torch from pipeline_anytext import AnyTextPipeline -from text_controlnet import TextControlNetModel +from text_controlnet import AnyTextControlNetModel from diffusers import DDIMScheduler from diffusers.utils import load_image -controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, +controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, variant="fp16") pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16") diff --git a/examples/research_projects/anytext/text_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py similarity index 100% rename from examples/research_projects/anytext/text_controlnet.py rename to examples/research_projects/anytext/anytext_controlnet.py diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 3a8d6017ae8c..41aac7c8511f 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -80,13 +80,13 @@ Examples: ```py >>> from pipeline_anytext import AnyTextPipeline - >>> from text_controlnet import TextControlNetModel + >>> from text_controlnet import AnyTextControlNetModel >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image >>> import torch >>> # load control net and stable diffusion v1-5 - >>> text_controlnet = TextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, ... variant="fp16",) >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, ... torch_dtype=torch.float16, variant="fp16", From d3dcf57fca45dbc3f5fbb1029fa4c10cb86cdb53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 13 Jan 2025 12:01:29 +0300 Subject: [PATCH 64/87] [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py --- examples/research_projects/anytext/pipeline_anytext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 41aac7c8511f..906bf2e10b6b 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -80,7 +80,7 @@ Examples: ```py >>> from pipeline_anytext import AnyTextPipeline - >>> from text_controlnet import AnyTextControlNetModel + >>> from anytext_controlnet import AnyTextControlNetModel >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image >>> import torch From 963fac0b89c150de3e7e0a69319a38b92b867482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 13 Jan 2025 12:02:15 +0300 Subject: [PATCH 65/87] [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py --- examples/research_projects/anytext/anytext_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index 8ad40af78063..abaad3a27173 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -17,7 +17,7 @@ from torch import nn from diffusers.configuration_utils import register_to_config -from diffusers.models.controlnet import ( +from diffusers.models.controlnets.controlnet import ( ControlNetModel, ControlNetOutput, ) From 2b6f08ba452fabed72f6952ffbb8803f301766cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 13 Jan 2025 18:57:41 +0300 Subject: [PATCH 66/87] Refactor AnyTextControlNet to use configurable conditioning embedding channels --- .../anytext/anytext_controlnet.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index abaad3a27173..937cb554c659 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -39,9 +39,9 @@ class AnyTextControlNetConditioningEmbedding(nn.Module): def __init__( self, + conditioning_embedding_channels: int, glyph_channels=1, position_channels=1, - model_channels=320, ): super().__init__() @@ -83,7 +83,7 @@ def __init__( nn.SiLU(), ) - self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1) + self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1) # self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) # self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) @@ -177,7 +177,7 @@ class conditioning with `class_embed_type` equal to `None`. def __init__( self, in_channels: int = 4, - conditioning_channels: int = 3, + conditioning_channels: int = 1, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str, ...] = ( @@ -251,11 +251,12 @@ def __init__( # control net conditioning embedding # TODO: what happens ControlNetModel's self.controlnet_cond_embedding's memory occupation? - self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) + # self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( + # conditioning_embedding_channels=block_out_channels[0], + # glyph_channels=conditioning_channels, + # position_channels=conditioning_channels, + # ) + self.controlnet_cond_embedding = None def forward( self, From 9c43a652e6612cfd92188f87f7a1165bba709f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 20 Jan 2025 17:28:40 +0300 Subject: [PATCH 67/87] Complete control net conditioning embedding in AnyTextControlNetModel --- .../research_projects/anytext/anytext_controlnet.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index 937cb554c659..9fa97569b87a 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -250,13 +250,11 @@ def __init__( ) # control net conditioning embedding - # TODO: what happens ControlNetModel's self.controlnet_cond_embedding's memory occupation? - # self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( - # conditioning_embedding_channels=block_out_channels[0], - # glyph_channels=conditioning_channels, - # position_channels=conditioning_channels, - # ) - self.controlnet_cond_embedding = None + self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + glyph_channels=conditioning_channels, + position_channels=conditioning_channels, + ) def forward( self, From b8ca0d67089480d104b68ea432675a9278ae59d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 20 Feb 2025 18:11:57 +0300 Subject: [PATCH 68/87] up --- .../anytext/pipeline_anytext.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py index 906bf2e10b6b..b961c3b37b0a 100644 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ b/examples/research_projects/anytext/pipeline_anytext.py @@ -67,6 +67,8 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.configuration_utils import register_to_config, ConfigMixin +from diffusers.models.modeling_utils import ModelMixin checker = BasicTokenizer() @@ -88,7 +90,7 @@ >>> # load control net and stable diffusion v1-5 >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, ... variant="fp16",) - >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, + >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, ... torch_dtype=torch.float16, variant="fp16", ... ).to("cuda") @@ -150,7 +152,7 @@ def __init__( self.token_dim = token_dim self.proj = nn.Linear(40 * 64, token_dim) - self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) + # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) if use_fp16: self.proj = self.proj.to(dtype=torch.float16) @@ -449,20 +451,19 @@ def get_ctcloss(self, preds, gt_text, weight): class TextEmbeddingModule(nn.Module): + # @register_to_config def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() - self.use_fp16 = use_fp16 - self.device = device # TODO: Learn if the recommended font file is free to use self.font = ImageFont.truetype(font_path, 60) - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) - rec_model_dir = "OCR/ppv3_rec.pth" - self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) + rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth" + self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval() args = {} args["rec_image_shape"] = "3, 48, 320" args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt" args["use_fp16"] = self.use_fp16 self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) @@ -843,9 +844,6 @@ def insert_spaces(self, string, nSpace): def to(self, device): self.device = device - self.glyph_block = self.glyph_block.to(device) - self.position_block = self.position_block.to(device) - self.fuse_block = self.fuse_block.to(device) self.vae = self.vae.to(device) return self @@ -1011,8 +1009,8 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, - # text_embedding_module=text_embedding_module, - # auxiliary_latent_module=auxiliary_latent_module, + # text_embedding_module=self.text_embedding_module, + # auxiliary_latent_module=self.auxiliary_latent_module, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) From 9657980a6eb3e8086f03f7997837b042c4799d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 21 Feb 2025 16:31:32 +0300 Subject: [PATCH 69/87] [FIX] Ensure embeddings use correct device in AnyTextControlNetModel --- examples/research_projects/anytext/anytext_controlnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index 9fa97569b87a..81f65a8315a4 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -95,8 +95,8 @@ def __init__( # self.fuse_block = self.fuse_block.to(dtype=torch.float16) def forward(self, glyphs, positions, text_info): - glyph_embedding = self.glyph_block(glyphs) - position_embedding = self.position_block(positions) + glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device)) + position_embedding = self.position_block(positions.to(self.position_block[0].weight.device)) guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1)) return guided_hint @@ -390,7 +390,7 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond) sample = sample + controlnet_cond # 3. down From 25ea8be778b84531be5cb33bb1a8de1a14a1d41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 21 Feb 2025 16:31:53 +0300 Subject: [PATCH 70/87] up --- examples/research_projects/anytext/anytext.py | 2152 +++++++++++++++++ 1 file changed, 2152 insertions(+) create mode 100644 examples/research_projects/anytext/anytext.py diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py new file mode 100644 index 000000000000..d7ee7df4a31c --- /dev/null +++ b/examples/research_projects/anytext/anytext.py @@ -0,0 +1,2152 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). +# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie +# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license +# +# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). + + +import inspect +import math +import os +import re +import sys +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import huggingface_hub +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from bert_tokenizer import BasicTokenizer +from easydict import EasyDict as edict +from diffusers.utils.constants import HF_MODULES_CACHE +from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from ocr_recog.RecModel import RecModel +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file +from skimage.transform._geometric import _umeyama as get_sym_mat +from torch import nn +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.configuration_utils import register_to_config, ConfigMixin +from diffusers.models.modeling_utils import ModelMixin +from huggingface_hub import hf_hub_download + + +checker = BasicTokenizer() + + +PLACE_HOLDER = "*" +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pipeline_anytext import AnyTextPipeline + >>> from anytext_controlnet import AnyTextControlNetModel + >>> from diffusers import DDIMScheduler + >>> from diffusers.utils import load_image + >>> import torch + + >>> # load control net and stable diffusion v1-5 + >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + ... variant="fp16",) + >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, + ... torch_dtype=torch.float16, variant="fp16", + ... ).to("cuda") + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization + >>> #pipe.enable_xformers_memory_efficient_attention() + + >>> # uncomment following line if you want to offload the model to CPU for memory optimization + >>> # also remove the `.to("cuda")` part + >>> #pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.Generator("cpu").manual_seed(66273235) + >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' + >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") + >>> image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", + ... draw_pos=draw_pos, + ... ).images[0] + >>> image + ``` +""" + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"] + assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" + return tokens[0, 1] + + +def get_recog_emb(encoder, img_list): + _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] + encoder.predictor.eval() + _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) + return preds_neck + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_string="*", + use_fp16=False, + ): + super().__init__() + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + token_dim = 768 + self.get_recog_emb = None + self.token_dim = token_dim + + self.proj = nn.Linear(40 * 64, token_dim) + proj_dir = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/proj.safetensors", + cache_dir=HF_MODULES_CACHE + ) + self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device))) + if use_fp16: + self.proj = self.proj.to(dtype=torch.float16) + + self.placeholder_token = get_token_for_string(placeholder_string) + + @torch.no_grad() + def encode_text(self, text_info): + if self.get_recog_emb is None: + self.get_recog_emb = partial(get_recog_emb, self.recog) + + gline_list = [] + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + for j in range(n_lines): # line + gline_list += [text_info["gly_line"][j][i : i + 1]] + + if len(gline_list) > 0: + recog_emb = self.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) + + self.text_embs_all = [] + n_idx = 0 + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + text_embs = [] + for j in range(n_lines): # line + text_embs += [enc_glyph[n_idx : n_idx + 1]] + n_idx += 1 + self.text_embs_all += [text_embs] + + @torch.no_grad() + def forward( + self, + tokenized_text, + embedded_text, + ): + b, device = tokenized_text.shape[0], tokenized_text.device + for i in range(b): + idx = tokenized_text[i] == self.placeholder_token.to(device) + if sum(idx) > 0: + if i >= len(self.text_embs_all): + print("truncation for log images...") + break + text_emb = torch.cat(self.text_embs_all[i], dim=0) + if sum(idx) != len(text_emb): + print("truncation for long caption...") + text_emb = text_emb.to(embedded_text.device) + embedded_text[i][idx] = text_emb[: sum(idx)] + return embedded_text + + def embedding_parameters(self): + return self.parameters() + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def min_bounding_rect(img): + ret, thresh = cv2.threshold(img, 127, 255, 0) + contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + print("Bad contours, using fake bbox...") + return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) + max_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(max_contour) + box = cv2.boxPoints(rect) + box = np.int0(box) + # sort + x_sorted = sorted(box, key=lambda x: x[0]) + left = x_sorted[:2] + right = x_sorted[2:] + left = sorted(left, key=lambda x: x[1]) + (tl, bl) = left + right = sorted(right, key=lambda x: x[1]) + (tr, br) = right + if tl[1] > bl[1]: + (tl, bl) = (bl, tl) + if tr[1] > br[1]: + (tr, br) = (br, tr) + return np.array([tl, tr, br, bl]) + + +def adjust_image(box, img): + pts1 = np.float32([box[0], box[1], box[2], box[3]]) + width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) + height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) + pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) + # get transform matrix + M = get_sym_mat(pts1, pts2, estimate_scale=True) + C, H, W = img.shape + T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) + theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) + theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) + grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) + result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) + result = torch.clamp(result.squeeze(0), 0, 255) + # crop + result = result[:, : int(height), : int(width)] + return result + + +""" +mask: numpy.ndarray, mask of textual, HWC +src_img: torch.Tensor, source image, CHW +""" + + +def crop_image(src_img, mask): + box = min_bounding_rect(mask) + result = adjust_image(box, src_img) + if len(result.shape) == 2: + result = torch.stack([result] * 3, axis=-1) + return result + + +def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): + if model_dir is None or not os.path.exists(model_dir): + model_dir = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppv3_rec.pth", + cache_dir=HF_MODULES_CACHE + ) + if not os.path.exists(model_dir): + raise ValueError("not find model file path {}".format(model_dir)) + + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 + else: + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = edict( + in_channels=3, + backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), + neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), + head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + ) + + rec_model = RecModel(rec_config) + state_dict = torch.load(model_dir, map_location=device) + rec_model.load_state_dict(state_dict) + return rec_model + + +def _check_image_file(path): + img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") + return path.lower().endswith(tuple(img_end)) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +class TextRecognizer(object): + def __init__(self, args, predictor): + self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] + self.rec_batch_num = args["rec_batch_num"] + self.predictor = predictor + self.chars = self.get_char_dict(args["rec_char_dict_path"]) + self.char2id = {x: i for i, x in enumerate(self.chars)} + self.is_onnx = not isinstance(self.predictor, torch.nn.Module) + self.use_fp16 = args["use_fp16"] + + # img: CHW + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[0] + imgW = int((imgH * max_wh_ratio)) + + h, w = img.shape[1:] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = torch.nn.functional.interpolate( + img.unsqueeze(0), + size=(imgH, resized_w), + mode="bilinear", + align_corners=True, + ) + resized_image /= 255.0 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) + padding_im[:, :, 0:resized_w] = resized_image[0] + return padding_im + + # img_list: list of tensors with shape chw 0-255 + def pred_imglist(self, img_list, show_debug=False): + img_num = len(img_list) + assert img_num > 0 + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[2] / float(img.shape[1])) + # Sorting can speed up the recognition process + indices = torch.from_numpy(np.argsort(np.array(width_list))) + batch_num = self.rec_batch_num + preds_all = [None] * img_num + preds_neck_all = [None] * img_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[1:] + if h > w * 1.2: + img = img_list[indices[ino]] + img = torch.transpose(img, 1, 2).flip(dims=[1]) + img_list[indices[ino]] = img + h, w = img.shape[1:] + # wh_ratio = w * 1.0 / h + # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + if self.use_fp16: + norm_img = norm_img.half() + norm_img = norm_img.unsqueeze(0) + norm_img_batch.append(norm_img) + norm_img_batch = torch.cat(norm_img_batch, dim=0) + if show_debug: + for i in range(len(norm_img_batch)): + _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() + _img = (_img + 0.5) * 255 + _img = _img[:, :, ::-1] + file_name = f"{indices[beg_img_no + i]}" + if os.path.exists(file_name + ".jpg"): + file_name += "_2" # ori image + cv2.imwrite(file_name + ".jpg", _img) + if self.is_onnx: + input_dict = {} + input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() + outputs = self.predictor.run(None, input_dict) + preds = {} + preds["ctc"] = torch.from_numpy(outputs[0]) + preds["ctc_neck"] = [torch.zeros(1)] * img_num + else: + preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device)) + for rno in range(preds["ctc"].shape[0]): + preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] + preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] + + return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) + + def get_char_dict(self, character_dict_path): + character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str.append(line) + dict_character = list(character_str) + dict_character = ["sos"] + dict_character + [" "] # eos is space + return dict_character + + def get_text(self, order): + char_list = [self.chars[text_id] for text_id in order] + return "".join(char_list) + + def decode(self, mat): + text_index = mat.detach().cpu().numpy().argmax(axis=1) + ignored_tokens = [0] + selection = np.ones(len(text_index), dtype=bool) + selection[1:] = text_index[1:] != text_index[:-1] + for ignored_token in ignored_tokens: + selection &= text_index != ignored_token + return text_index[selection], np.where(selection)[0] + + def get_ctcloss(self, preds, gt_text, weight): + if not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight).to(preds.device) + ctc_loss = torch.nn.CTCLoss(reduction="none") + log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC + targets = [] + target_lengths = [] + for t in gt_text: + targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] + target_lengths += [len(t)] + targets = torch.tensor(targets).to(preds.device) + target_lengths = torch.tensor(target_lengths).to(preds.device) + input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) + loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) + loss = loss / input_lengths * weight + return loss + + +class TextEmbeddingModule(nn.Module): + # @register_to_config + def __init__(self, font_path, use_fp16=False, device="cpu"): + super().__init__() + # TODO: Learn if the recommended font file is free to use + self.font = ImageFont.truetype(font_path, 60) + self.use_fp16 = use_fp16 + self.device = device + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) + rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth" + self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval() + args = {} + args["rec_image_shape"] = "3, 48, 320" + args["rec_batch_num"] = 6 + args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt" + args["rec_char_dict_path"] = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppocr_keys_v1.txt", + cache_dir=HF_MODULES_CACHE + ) + args["use_fp16"] = use_fp16 + self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) + + @torch.no_grad() + def forward( + self, + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + sort_priority="↕", + max_chars=77, + revise_pos=False, + h=512, + w=512, + ): + if prompt is None and texts is None: + raise ValueError("Prompt or texts must be provided!") + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, PIL.Image.Image): + pos_imgs = np.array(draw_pos)[..., ::-1] + pos_imgs = 255 - pos_imgs + elif isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from {draw_pos}!") + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + if mode == "edit": + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + n_lines = len(texts) + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + text_info = {} + text_info["glyphs"] = [] + text_info["gly_line"] = [] + text_info["positions"] = [] + text_info["n_lines"] = [len(texts)] * num_images_per_prompt + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = self.draw_glyph(self.font, text) + glyphs = self.draw_glyph2( + self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + logger.warning(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + pos = pre_pos[i] + text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] + text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] + text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] + + # hint = self.arr2tensor(np_hint, len(prompt)) + + self.embedding_manager.encode_text(text_info) + prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) + + self.embedding_manager.encode_text(text_info) + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( + [negative_prompt or ''], embedding_manager=self.embedding_manager + ) + + return prompt_embeds, negative_prompt_embeds, text_info, np_hint + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().cpu() + if self.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr + + def separate_pos_imgs(self, img, sort_priority, gap=102): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) + components = [] + for label in range(1, num_labels): + component = np.zeros_like(img) + component[labels == label] = 255 + components.append((component, centroids[label])) + if sort_priority == "↕": + fir, sec = 1, 0 # top-down first + elif sort_priority == "↔": + fir, sec = 0, 1 # left-right first + else: + raise ValueError(f"Unknown sort_priority: {sort_priority}") + components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) + sorted_components = [c[0] for c in components] + return sorted_components + + def find_polygon(self, image, min_rect=False): + contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + max_contour = max(contours, key=cv2.contourArea) # get contour with max area + if min_rect: + # get minimum enclosing rectangle + rect = cv2.minAreaRect(max_contour) + poly = np.int0(cv2.boxPoints(rect)) + else: + # get approximate polygon + epsilon = 0.01 * cv2.arcLength(max_contour, True) + poly = cv2.approxPolyDP(max_contour, epsilon, True) + n, _, xy = poly.shape + poly = poly.reshape(n, xy) + cv2.drawContours(image, [poly], -1, 255, -1) + return poly, image + + def draw_glyph(self, font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def to(self, *args, **kwargs): + self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) + self.embedding_manager = self.embedding_manager.to(*args, **kwargs) + self.text_predictor = self.text_predictor.to(*args, **kwargs) + self.device = self.frozen_CLIP_embedder_t3.device + return self + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AuxiliaryLatentModule(nn.Module): + def __init__( + self, + font_path, + vae=None, + device="cpu", + use_fp16=False, + ): + super().__init__() + self.font = ImageFont.truetype(font_path, 60) + self.use_fp16 = use_fp16 + self.device = device + + self.vae = vae.eval() if vae is not None else None + + @torch.no_grad() + def forward( + self, + text_info, + mode, + draw_pos, + ori_image, + num_images_per_prompt, + np_hint, + h=512, + w=512, + ): + if mode == "generate": + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode == "edit": + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + device = next(self.vae.parameters()).device + masked_img = torch.from_numpy(masked_img.copy()).float().to(device) + if self.use_fp16: + masked_img = masked_img.half() + masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() + if self.use_fp16: + masked_x = masked_x.half() + text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) + + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) + + return glyphs, positions, text_info + + def check_channels(self, image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + def resize_image(self, img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + def to(self, *args, **kwargs): + self.vae = self.vae.to(*args, **kwargs) + self.device = self.vae.device + return self + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnyTextPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + font_path: str, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + trust_remote_code: bool = False, + text_embedding_module: TextEmbeddingModule = None, + auxiliary_latent_module: AuxiliaryLatentModule = None, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + self.text_embedding_module = TextEmbeddingModule( + use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + ) + self.auxiliary_latent_module = AuxiliaryLatentModule( + vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + text_embedding_module=self.text_embedding_module, + auxiliary_latent_module=self.auxiliary_latent_module, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path) + + def modify_prompt(self, prompt): + prompt = prompt.replace("“", '"') + prompt = prompt.replace("”", '"') + p = '"(.*?)"' + strs = re.findall(p, prompt) + if len(strs) == 0: + strs = [" "] + else: + for s in strs: + prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) + if self.is_chinese(prompt): + if self.trans_pipe is None: + return None, None + old_prompt = prompt + prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] + print(f"Translate: {old_prompt} --> {prompt}") + return prompt, strs + + def is_chinese(self, text): + text = checker._clean_text(text) + for char in text: + cp = ord(char) + if checker._is_chinese_char(cp): + return True + return False + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + # image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + print(controlnet_conditioning_scale) + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + mode: Optional[str] = "generate", + draw_pos: Optional[Union[str, torch.Tensor]] = None, + ori_image: Optional[Union[str, torch.Tensor]] = None, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + # image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + prompt, texts = self.modify_prompt(prompt) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos + prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module( + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + # image = self.prepare_image( + # image=image, + # width=width, + # height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + # device=device, + # dtype=controlnet.dtype, + # do_classifier_free_guidance=self.do_classifier_free_guidance, + # guess_mode=guess_mode, + # ) + # height, width = image.shape[-2:] + guided_hint = self.auxiliary_latent_module( + text_info=text_info, + mode=mode, + draw_pos=draw_pos, + ori_image=ori_image, + num_images_per_prompt=num_images_per_prompt, + np_hint=np_hint, + ) + height, width = 512, 512 + # elif isinstance(controlnet, MultiControlNetModel): + # images = [] + + # # Nested lists as ControlNet condition + # if isinstance(image[0], list): + # # Transpose the nested image list + # image = [list(t) for t in zip(*image)] + + # for image_ in image: + # image_ = self.prepare_image( + # image=image_, + # width=width, + # height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + # device=device, + # dtype=controlnet.dtype, + # do_classifier_free_guidance=self.do_classifier_free_guidance, + # guess_mode=guess_mode, + # ) + + # images.append(image_) + + # image = images + # height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=guided_hint, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.text_embedding_module.to(*args, **kwargs) + self.auxiliary_latent_module.to(*args, **kwargs) + return self From 2be7bca1184d06f0dc19dac68615be6c813b76e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 21 Feb 2025 16:33:07 +0300 Subject: [PATCH 71/87] up --- .../anytext/pipeline_anytext.py | 2118 ----------------- 1 file changed, 2118 deletions(-) delete mode 100644 examples/research_projects/anytext/pipeline_anytext.py diff --git a/examples/research_projects/anytext/pipeline_anytext.py b/examples/research_projects/anytext/pipeline_anytext.py deleted file mode 100644 index b961c3b37b0a..000000000000 --- a/examples/research_projects/anytext/pipeline_anytext.py +++ /dev/null @@ -1,2118 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# Copyright (c) Alibaba, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). -# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie -# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license -# -# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). - - -import inspect -import math -import os -import re -import sys -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import cv2 -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from bert_tokenizer import BasicTokenizer -from easydict import EasyDict as edict -from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 -from ocr_recog.RecModel import RecModel -from PIL import Image, ImageDraw, ImageFont -from safetensors.torch import load_file -from skimage.transform._geometric import _umeyama as get_sym_mat -from torch import nn -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection - -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - StableDiffusionLoraLoaderMixin, - TextualInversionLoaderMixin, -) -from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from diffusers.configuration_utils import register_to_config, ConfigMixin -from diffusers.models.modeling_utils import ModelMixin - - -checker = BasicTokenizer() - - -PLACE_HOLDER = "*" -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> from pipeline_anytext import AnyTextPipeline - >>> from anytext_controlnet import AnyTextControlNetModel - >>> from diffusers import DDIMScheduler - >>> from diffusers.utils import load_image - >>> import torch - - >>> # load control net and stable diffusion v1-5 - >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, - ... variant="fp16",) - >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, - ... torch_dtype=torch.float16, variant="fp16", - ... ).to("cuda") - - >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) - >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization - >>> #pipe.enable_xformers_memory_efficient_attention() - - >>> # uncomment following line if you want to offload the model to CPU for memory optimization - >>> # also remove the `.to("cuda")` part - >>> #pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.Generator("cpu").manual_seed(66273235) - >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' - >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") - >>> image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", - ... draw_pos=draw_pos, - ... ).images[0] - >>> image - ``` -""" - - -def get_clip_token_for_string(tokenizer, string): - batch_encoding = tokenizer( - string, - truncation=True, - max_length=77, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"] - assert ( - torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" - return tokens[0, 1] - - -def get_recog_emb(encoder, img_list): - _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] - encoder.predictor.eval() - _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) - return preds_neck - - -class EmbeddingManager(nn.Module): - def __init__( - self, - embedder, - placeholder_string="*", - use_fp16=False, - ): - super().__init__() - get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - token_dim = 768 - self.get_recog_emb = None - self.token_dim = token_dim - - self.proj = nn.Linear(40 * 64, token_dim) - # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) - if use_fp16: - self.proj = self.proj.to(dtype=torch.float16) - - self.placeholder_token = get_token_for_string(placeholder_string) - - @torch.no_grad() - def encode_text(self, text_info): - if self.get_recog_emb is None: - self.get_recog_emb = partial(get_recog_emb, self.recog) - - gline_list = [] - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - for j in range(n_lines): # line - gline_list += [text_info["gly_line"][j][i : i + 1]] - - if len(gline_list) > 0: - recog_emb = self.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) - - self.text_embs_all = [] - n_idx = 0 - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - text_embs = [] - for j in range(n_lines): # line - text_embs += [enc_glyph[n_idx : n_idx + 1]] - n_idx += 1 - self.text_embs_all += [text_embs] - - @torch.no_grad() - def forward( - self, - tokenized_text, - embedded_text, - ): - b, device = tokenized_text.shape[0], tokenized_text.device - for i in range(b): - idx = tokenized_text[i] == self.placeholder_token.to(device) - if sum(idx) > 0: - if i >= len(self.text_embs_all): - print("truncation for log images...") - break - text_emb = torch.cat(self.text_embs_all[i], dim=0) - if sum(idx) != len(text_emb): - print("truncation for long caption...") - text_emb = text_emb.to(embedded_text.device) - embedded_text[i][idx] = text_emb[: sum(idx)] - return embedded_text - - def embedding_parameters(self): - return self.parameters() - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - - -def min_bounding_rect(img): - ret, thresh = cv2.threshold(img, 127, 255, 0) - contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if len(contours) == 0: - print("Bad contours, using fake bbox...") - return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) - max_contour = max(contours, key=cv2.contourArea) - rect = cv2.minAreaRect(max_contour) - box = cv2.boxPoints(rect) - box = np.int0(box) - # sort - x_sorted = sorted(box, key=lambda x: x[0]) - left = x_sorted[:2] - right = x_sorted[2:] - left = sorted(left, key=lambda x: x[1]) - (tl, bl) = left - right = sorted(right, key=lambda x: x[1]) - (tr, br) = right - if tl[1] > bl[1]: - (tl, bl) = (bl, tl) - if tr[1] > br[1]: - (tr, br) = (br, tr) - return np.array([tl, tr, br, bl]) - - -def adjust_image(box, img): - pts1 = np.float32([box[0], box[1], box[2], box[3]]) - width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) - height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) - pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) - # get transform matrix - M = get_sym_mat(pts1, pts2, estimate_scale=True) - C, H, W = img.shape - T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) - theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) - theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) - grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) - result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) - result = torch.clamp(result.squeeze(0), 0, 255) - # crop - result = result[:, : int(height), : int(width)] - return result - - -""" -mask: numpy.ndarray, mask of textual, HWC -src_img: torch.Tensor, source image, CHW -""" - - -def crop_image(src_img, mask): - box = min_bounding_rect(mask) - result = adjust_image(box, src_img) - if len(result.shape) == 2: - result = torch.stack([result] * 3, axis=-1) - return result - - -def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): - model_file_path = model_dir - if model_file_path is not None and not os.path.exists(model_file_path): - raise ValueError("not find model file path {}".format(model_file_path)) - - if model_lang == "ch": - n_class = 6625 - elif model_lang == "en": - n_class = 97 - else: - raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = edict( - in_channels=3, - backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), - neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), - head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), - ) - - rec_model = RecModel(rec_config) - if model_file_path is not None: - rec_model.load_state_dict(torch.load(model_file_path, map_location=device)) - return rec_model - - -def _check_image_file(path): - img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") - return path.lower().endswith(tuple(img_end)) - - -def get_image_file_list(img_file): - imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - if os.path.isfile(img_file) and _check_image_file(img_file): - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for single_file in os.listdir(img_file): - file_path = os.path.join(img_file, single_file) - if os.path.isfile(file_path) and _check_image_file(file_path): - imgs_lists.append(file_path) - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - imgs_lists = sorted(imgs_lists) - return imgs_lists - - -class TextRecognizer(object): - def __init__(self, args, predictor): - self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] - self.rec_batch_num = args["rec_batch_num"] - self.predictor = predictor - self.chars = self.get_char_dict(args["rec_char_dict_path"]) - self.char2id = {x: i for i, x in enumerate(self.chars)} - self.is_onnx = not isinstance(self.predictor, torch.nn.Module) - self.use_fp16 = args["use_fp16"] - - # img: CHW - def resize_norm_img(self, img, max_wh_ratio): - imgC, imgH, imgW = self.rec_image_shape - assert imgC == img.shape[0] - imgW = int((imgH * max_wh_ratio)) - - h, w = img.shape[1:] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = torch.nn.functional.interpolate( - img.unsqueeze(0), - size=(imgH, resized_w), - mode="bilinear", - align_corners=True, - ) - resized_image /= 255.0 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) - padding_im[:, :, 0:resized_w] = resized_image[0] - return padding_im - - # img_list: list of tensors with shape chw 0-255 - def pred_imglist(self, img_list, show_debug=False): - img_num = len(img_list) - assert img_num > 0 - # Calculate the aspect ratio of all text bars - width_list = [] - for img in img_list: - width_list.append(img.shape[2] / float(img.shape[1])) - # Sorting can speed up the recognition process - indices = torch.from_numpy(np.argsort(np.array(width_list))) - batch_num = self.rec_batch_num - preds_all = [None] * img_num - preds_neck_all = [None] * img_num - for beg_img_no in range(0, img_num, batch_num): - end_img_no = min(img_num, beg_img_no + batch_num) - norm_img_batch = [] - - imgC, imgH, imgW = self.rec_image_shape[:3] - max_wh_ratio = imgW / imgH - for ino in range(beg_img_no, end_img_no): - h, w = img_list[indices[ino]].shape[1:] - if h > w * 1.2: - img = img_list[indices[ino]] - img = torch.transpose(img, 1, 2).flip(dims=[1]) - img_list[indices[ino]] = img - h, w = img.shape[1:] - # wh_ratio = w * 1.0 / h - # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio - for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) - if self.use_fp16: - norm_img = norm_img.half() - norm_img = norm_img.unsqueeze(0) - norm_img_batch.append(norm_img) - norm_img_batch = torch.cat(norm_img_batch, dim=0) - if show_debug: - for i in range(len(norm_img_batch)): - _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() - _img = (_img + 0.5) * 255 - _img = _img[:, :, ::-1] - file_name = f"{indices[beg_img_no + i]}" - if os.path.exists(file_name + ".jpg"): - file_name += "_2" # ori image - cv2.imwrite(file_name + ".jpg", _img) - if self.is_onnx: - input_dict = {} - input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() - outputs = self.predictor.run(None, input_dict) - preds = {} - preds["ctc"] = torch.from_numpy(outputs[0]) - preds["ctc_neck"] = [torch.zeros(1)] * img_num - else: - preds = self.predictor(norm_img_batch) - for rno in range(preds["ctc"].shape[0]): - preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] - preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] - - return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) - - def get_char_dict(self, character_dict_path): - character_str = [] - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - line = line.decode("utf-8").strip("\n").strip("\r\n") - character_str.append(line) - dict_character = list(character_str) - dict_character = ["sos"] + dict_character + [" "] # eos is space - return dict_character - - def get_text(self, order): - char_list = [self.chars[text_id] for text_id in order] - return "".join(char_list) - - def decode(self, mat): - text_index = mat.detach().cpu().numpy().argmax(axis=1) - ignored_tokens = [0] - selection = np.ones(len(text_index), dtype=bool) - selection[1:] = text_index[1:] != text_index[:-1] - for ignored_token in ignored_tokens: - selection &= text_index != ignored_token - return text_index[selection], np.where(selection)[0] - - def get_ctcloss(self, preds, gt_text, weight): - if not isinstance(weight, torch.Tensor): - weight = torch.tensor(weight).to(preds.device) - ctc_loss = torch.nn.CTCLoss(reduction="none") - log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC - targets = [] - target_lengths = [] - for t in gt_text: - targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] - target_lengths += [len(t)] - targets = torch.tensor(targets).to(preds.device) - target_lengths = torch.tensor(target_lengths).to(preds.device) - input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) - loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) - loss = loss / input_lengths * weight - return loss - - -class TextEmbeddingModule(nn.Module): - # @register_to_config - def __init__(self, font_path, use_fp16=False, device="cpu"): - super().__init__() - # TODO: Learn if the recommended font file is free to use - self.font = ImageFont.truetype(font_path, 60) - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) - rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth" - self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval() - args = {} - args["rec_image_shape"] = "3, 48, 320" - args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt" - args["use_fp16"] = self.use_fp16 - self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) - - @torch.no_grad() - def forward( - self, - prompt, - texts, - negative_prompt, - num_images_per_prompt, - mode, - draw_pos, - sort_priority="↕", - max_chars=77, - revise_pos=False, - h=512, - w=512, - ): - if prompt is None and texts is None: - raise ValueError("Prompt or texts must be provided!") - # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) - if draw_pos is None: - pos_imgs = np.zeros((w, h, 1)) - if isinstance(draw_pos, str): - draw_pos = cv2.imread(draw_pos)[..., ::-1] - if draw_pos is None: - raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - pos_imgs = 255 - draw_pos - elif isinstance(draw_pos, torch.Tensor): - pos_imgs = draw_pos.cpu().numpy() - else: - if not isinstance(draw_pos, np.ndarray): - raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - if mode == "edit": - pos_imgs = cv2.resize(pos_imgs, (w, h)) - pos_imgs = pos_imgs[..., 0:1] - pos_imgs = cv2.convertScaleAbs(pos_imgs) - _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) - # separate pos_imgs - pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) - if len(pos_imgs) == 0: - pos_imgs = [np.zeros((h, w, 1))] - n_lines = len(texts) - if len(pos_imgs) < n_lines: - if n_lines == 1 and texts[0] == " ": - pass # text-to-image without text - else: - raise ValueError( - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" - ) - elif len(pos_imgs) > n_lines: - str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." - logger.warning(str_warning) - # get pre_pos, poly_list, hint that needed for anytext - pre_pos = [] - poly_list = [] - for input_pos in pos_imgs: - if input_pos.mean() != 0: - input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos - poly, pos_img = self.find_polygon(input_pos) - pre_pos += [pos_img / 255.0] - poly_list += [poly] - else: - pre_pos += [np.zeros((h, w, 1))] - poly_list += [None] - np_hint = np.sum(pre_pos, axis=0).clip(0, 1) - # prepare info dict - text_info = {} - text_info["glyphs"] = [] - text_info["gly_line"] = [] - text_info["positions"] = [] - text_info["n_lines"] = [len(texts)] * num_images_per_prompt - for i in range(len(texts)): - text = texts[i] - if len(text) > max_chars: - str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' - logger.warning(str_warning) - text = text[:max_chars] - gly_scale = 2 - if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.font, text) - glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - if revise_pos: - resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) - new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos - contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - logger.warning(str_warning) - else: - rect = cv2.minAreaRect(contours[0]) - poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 - else: - glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_line = np.zeros((80, 512, 1)) - pos = pre_pos[i] - text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] - text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] - text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] - - # hint = self.arr2tensor(np_hint, len(prompt)) - - self.embedding_manager.encode_text(text_info) - prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) - - self.embedding_manager.encode_text(text_info) - negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( - [negative_prompt], embedding_manager=self.embedding_manager - ) - - return prompt_embeds, negative_prompt_embeds, text_info, np_hint - - def arr2tensor(self, arr, bs): - arr = np.transpose(arr, (2, 0, 1)) - _arr = torch.from_numpy(arr.copy()).float().cpu() - if self.use_fp16: - _arr = _arr.half() - _arr = torch.stack([_arr for _ in range(bs)], dim=0) - return _arr - - def separate_pos_imgs(self, img, sort_priority, gap=102): - num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) - components = [] - for label in range(1, num_labels): - component = np.zeros_like(img) - component[labels == label] = 255 - components.append((component, centroids[label])) - if sort_priority == "↕": - fir, sec = 1, 0 # top-down first - elif sort_priority == "↔": - fir, sec = 0, 1 # left-right first - else: - raise ValueError(f"Unknown sort_priority: {sort_priority}") - components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) - sorted_components = [c[0] for c in components] - return sorted_components - - def find_polygon(self, image, min_rect=False): - contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - max_contour = max(contours, key=cv2.contourArea) # get contour with max area - if min_rect: - # get minimum enclosing rectangle - rect = cv2.minAreaRect(max_contour) - poly = np.int0(cv2.boxPoints(rect)) - else: - # get approximate polygon - epsilon = 0.01 * cv2.arcLength(max_contour, True) - poly = cv2.approxPolyDP(max_contour, epsilon, True) - n, _, xy = poly.shape - poly = poly.reshape(n, xy) - cv2.drawContours(image, [poly], -1, 255, -1) - return poly, image - - def draw_glyph(self, font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = self.insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = self.insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - def to(self, *args, **kwargs): - self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) - self.embedding_manager = self.embedding_manager.to(*args, **kwargs) - self.text_predictor = self.text_predictor.to(*args, **kwargs) - self.device = self.frozen_CLIP_embedder_t3.device - return self - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -class AuxiliaryLatentModule(nn.Module): - def __init__( - self, - font_path, - vae=None, - device="cpu", - use_fp16=False, - ): - super().__init__() - self.font = ImageFont.truetype(font_path, 60) - self.use_fp16 = use_fp16 - self.device = device - - self.vae = vae.eval() if vae is not None else None - - @torch.no_grad() - def forward( - self, - text_info, - mode, - draw_pos, - ori_image, - num_images_per_prompt, - np_hint, - h=512, - w=512, - ): - if mode == "generate": - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode == "edit": - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from {ori_image}!") - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - - # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) - masked_img = np.transpose(masked_img, (2, 0, 1)) - masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) - if self.use_fp16: - masked_img = masked_img.half() - masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() - if self.use_fp16: - masked_x = masked_x.half() - text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) - - glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - - return glyphs, positions, text_info - - def check_channels(self, image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - def resize_image(self, img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - def to(self, device): - self.device = device - self.vae = self.vae.to(device) - return self - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class AnyTextPipeline( - DiffusionPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionLoraLoaderMixin, - IPAdapterMixin, - FromSingleFileMixin, -): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): - Provides additional conditioning to the `unet` during the denoising process. If you set multiple - ControlNets as a list, the outputs from each ControlNet are added together to create one combined - additional conditioning. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. - """ - - model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - font_path: str, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - image_encoder: CLIPVisionModelWithProjection = None, - requires_safety_checker: bool = True, - ): - super().__init__() - self.text_embedding_module = TextEmbeddingModule( - use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path - ) - self.auxiliary_latent_module = AuxiliaryLatentModule( - vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path - ) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - if isinstance(controlnet, (list, tuple)): - controlnet = MultiControlNetModel(controlnet) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # text_embedding_module=self.text_embedding_module, - # auxiliary_latent_module=self.auxiliary_latent_module, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path) - - def modify_prompt(self, prompt): - prompt = prompt.replace("“", '"') - prompt = prompt.replace("”", '"') - p = '"(.*?)"' - strs = re.findall(p, prompt) - if len(strs) == 0: - strs = [" "] - else: - for s in strs: - prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) - if self.is_chinese(prompt): - if self.trans_pipe is None: - return None, None - old_prompt = prompt - prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] - print(f"Translate: {old_prompt} --> {prompt}") - return prompt, strs - - def is_chinese(self, text): - text = checker._clean_text(text) - for char in text: - cp = ord(char) - if checker._is_chinese_char(cp): - return True - return False - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - # image, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ip_adapter_image=None, - ip_adapter_image_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - callback_on_step_end_tensor_inputs=None, - ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - print(controlnet_conditioning_scale) - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError( - "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " - "The conditioning scale must be fixed across the batch." - ) - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - assert False - - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] - - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - if ip_adapter_image is not None and ip_adapter_image_embeds is not None: - raise ValueError( - "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." - ) - - if ip_adapter_image_embeds is not None: - if not isinstance(ip_adapter_image_embeds, list): - raise ValueError( - f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" - ) - elif ip_adapter_image_embeds[0].ndim not in [3, 4]: - raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" - ) - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def clip_skip(self): - return self._clip_skip - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - mode: Optional[str] = "generate", - draw_pos: Optional[Union[str, torch.Tensor]] = None, - ori_image: Optional[Union[str, torch.Tensor]] = None, - timesteps: List[int] = None, - sigmas: List[float] = None, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted - as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or - width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, - images must be passed as a list such that each element of the list can be correctly batched for input - to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single - ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple - ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of - IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should - contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set - the corresponding scale as a list. - guess_mode (`bool`, *optional*, defaults to `False`): - The ControlNet encoder tries to recognize the content of the input image even if you remove all - prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of - each denoising step during the inference. with the following arguments: `callback_on_step_end(self: - DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a - list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - # image, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - - prompt, texts = self.modify_prompt(prompt) - - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module( - prompt, - texts, - negative_prompt, - num_images_per_prompt, - mode, - draw_pos, - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, - ip_adapter_image_embeds, - device, - batch_size * num_images_per_prompt, - self.do_classifier_free_guidance, - ) - - # 3.5 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - # 4. Prepare image - if isinstance(controlnet, ControlNetModel): - # image = self.prepare_image( - # image=image, - # width=width, - # height=height, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=self.do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - # height, width = image.shape[-2:] - guided_hint = self.auxiliary_latent_module( - text_info=text_info, - mode=mode, - draw_pos=draw_pos, - ori_image=ori_image, - num_images_per_prompt=num_images_per_prompt, - np_hint=np_hint, - ) - height, width = 512, 512 - # elif isinstance(controlnet, MultiControlNetModel): - # images = [] - - # # Nested lists as ControlNet condition - # if isinstance(image[0], list): - # # Transpose the nested image list - # image = [list(t) for t in zip(*image)] - - # for image_ in image: - # image_ = self.prepare_image( - # image=image_, - # width=width, - # height=height, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=self.do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - - # images.append(image_) - - # image = images - # height, width = image[0].shape[-2:] - else: - assert False - - # 5. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self._num_timesteps = len(timesteps) - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} - if ip_adapter_image is not None or ip_adapter_image_embeds is not None - else None - ) - - # 7.2 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # controlnet(s) inference - if guess_mode and self.do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - guided_hint=guided_hint, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and self.do_classifier_free_guidance: - # Inferred ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=self.cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 0fc4aabee2c01f10966671adfe1d36ff86295f9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 21 Feb 2025 16:34:18 +0300 Subject: [PATCH 72/87] style --- examples/research_projects/anytext/anytext.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index d7ee7df4a31c..d033a9af7bc7 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -29,15 +29,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cv2 -import huggingface_hub import numpy as np import PIL.Image import torch import torch.nn.functional as F from bert_tokenizer import BasicTokenizer from easydict import EasyDict as edict -from diffusers.utils.constants import HF_MODULES_CACHE from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 +from huggingface_hub import hf_hub_download from ocr_recog.RecModel import RecModel from PIL import Image, ImageDraw, ImageFont from safetensors.torch import load_file @@ -68,10 +67,8 @@ scale_lora_layers, unscale_lora_layers, ) +from diffusers.utils.constants import HF_MODULES_CACHE from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from diffusers.configuration_utils import register_to_config, ConfigMixin -from diffusers.models.modeling_utils import ModelMixin -from huggingface_hub import hf_hub_download checker = BasicTokenizer() @@ -158,7 +155,7 @@ def __init__( proj_dir = hf_hub_download( repo_id="tolgacangoz/anytext", filename="text_embedding_module/proj.safetensors", - cache_dir=HF_MODULES_CACHE + cache_dir=HF_MODULES_CACHE, ) self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device))) if use_fp16: @@ -281,7 +278,7 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal model_dir = hf_hub_download( repo_id="tolgacangoz/anytext", filename="text_embedding_module/OCR/ppv3_rec.pth", - cache_dir=HF_MODULES_CACHE + cache_dir=HF_MODULES_CACHE, ) if not os.path.exists(model_dir): raise ValueError("not find model file path {}".format(model_dir)) @@ -482,7 +479,7 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): args["rec_char_dict_path"] = hf_hub_download( repo_id="tolgacangoz/anytext", filename="text_embedding_module/OCR/ppocr_keys_v1.txt", - cache_dir=HF_MODULES_CACHE + cache_dir=HF_MODULES_CACHE, ) args["use_fp16"] = use_fp16 self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) @@ -603,7 +600,7 @@ def forward( self.embedding_manager.encode_text(text_info) negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( - [negative_prompt or ''], embedding_manager=self.embedding_manager + [negative_prompt or ""], embedding_manager=self.embedding_manager ) return prompt_embeds, negative_prompt_embeds, text_info, np_hint From 5345702d351f0f916a04580c3294ab4ca335187f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 11:30:14 +0300 Subject: [PATCH 73/87] [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline --- examples/research_projects/anytext/README.md | 50 +++++++++---------- examples/research_projects/anytext/anytext.py | 19 ++++--- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index 661414e2d02a..2c2d2e131cdc 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -1,45 +1,43 @@ # AnyTextPipeline Pipeline -From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/) +From the repo [page](https://github.com/tyxsspa/AnyText) -"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types." +"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy." -For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115). - -Prepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion) - -To convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/): - -```sh -python convert_original_anytext_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory -``` - -To learn about how to convert the fine-tuned stable diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats). +For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). ```py import torch -from pipeline_anytext import AnyTextPipeline -from text_controlnet import AnyTextControlNetModel +from diffusers import DiffusionPipeline +from anytext_controlnet import AnyTextControlNetModel from diffusers import DDIMScheduler from diffusers.utils import load_image -controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, - variant="fp16") -pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=controlnet, - torch_dtype=torch.float16, variant="fp16") +# I chose a font file shared by an HF staff: +!wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf + +# load control net and stable diffusion v1-5 +anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + variant="fp16",) +pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", + controlnet=anytext_controlnet, torch_dtype=torch.float16, + trust_remote_code=True, + ).to("cuda") -# speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -# uncomment following line if torch<2.0 +# uncomment following line if PyTorch>=2.0 is not installed for memory optimization #pipe.enable_xformers_memory_efficient_attention() -pipe.enable_model_cpu_offload() + +# uncomment following line if you want to offload the model to CPU for memory optimization +# also remove the `.to("cuda")` part +#pipe.enable_model_cpu_offload() + # generate image -generator = torch.Generator("cpu").manual_seed(66273235) prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' -draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") -image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", draw_pos=draw_pos, - ).images[0] +draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") +image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, + ).images[0] image ``` diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index d033a9af7bc7..d1b2e55752b6 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -81,18 +81,19 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> from pipeline_anytext import AnyTextPipeline + >>> from diffusers import DiffusionPipeline >>> from anytext_controlnet import AnyTextControlNetModel >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image >>> import torch >>> # load control net and stable diffusion v1-5 - >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, - ... variant="fp16",) - >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, - ... torch_dtype=torch.float16, variant="fp16", - ... ).to("cuda") + >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + ... variant="fp16",) + >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="Arial_Unicode2.ttf", + ... controlnet=anytext_controlnet, torch_dtype=torch.float16, + ... trust_remote_code=True, + ... ).to("cuda") >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization @@ -103,11 +104,9 @@ >>> #pipe.enable_model_cpu_offload() >>> # generate image - >>> generator = torch.Generator("cpu").manual_seed(66273235) >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' - >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png") - >>> image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate", - ... draw_pos=draw_pos, + >>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") + >>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, ... ).images[0] >>> image ``` From 5b73a1d97183cd5a88f1fe29945e9b4d1f782b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 11:32:12 +0300 Subject: [PATCH 74/87] [UPDATE] Update example code in anytext.py to use correct font file and improve clarity --- examples/research_projects/anytext/anytext.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index d1b2e55752b6..fdf52792f045 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -81,16 +81,19 @@ EXAMPLE_DOC_STRING = """ Examples: ```py + >>> import torch >>> from diffusers import DiffusionPipeline >>> from anytext_controlnet import AnyTextControlNetModel >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image - >>> import torch + + >>> # I chose a font file shared by an HF staff: + >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf >>> # load control net and stable diffusion v1-5 >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, ... variant="fp16",) - >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="Arial_Unicode2.ttf", + >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", ... controlnet=anytext_controlnet, torch_dtype=torch.float16, ... trust_remote_code=True, ... ).to("cuda") From 7f87755d8bdc3717a599bd259474aa753b9f3ee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 12:26:08 +0300 Subject: [PATCH 75/87] down --- ...convert_anytext_controlnet_to_diffusers.py | 111 - .../anytext/convert_from_ckpt.py | 1872 ----------------- .../convert_original_anytext_to_diffusers.py | 1 - 3 files changed, 1984 deletions(-) delete mode 100644 examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py delete mode 100644 examples/research_projects/anytext/convert_from_ckpt.py delete mode 100644 examples/research_projects/anytext/convert_original_anytext_to_diffusers.py diff --git a/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py b/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py deleted file mode 100644 index 52c3b5281b41..000000000000 --- a/examples/research_projects/anytext/convert_anytext_controlnet_to_diffusers.py +++ /dev/null @@ -1,111 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.""" - -import argparse - -from convert_from_ckpt import download_controlnet_from_original_ckpt - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." - ) - parser.add_argument( - "--original_config_file", - type=str, - required=False, - help="The YAML config file corresponding to the original architecture.", - ) - parser.add_argument( - "--num_in_channels", - default=None, - type=int, - help="The number of input channels. If `None` number of input channels will be automatically inferred.", - ) - parser.add_argument( - "--image_size", - default=512, - type=int, - help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" - " Base. Use 768 for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--extract_ema", - action="store_true", - help=( - "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" - " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" - " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." - ), - ) - parser.add_argument( - "--upcast_attention", - action="store_true", - help=( - "Whether the attention computation should always be upcasted. This is necessary when running stable" - " diffusion 2.1." - ), - ) - parser.add_argument( - "--from_safetensors", - action="store_true", - help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", - ) - parser.add_argument( - "--to_safetensors", - action="store_true", - help="Whether to store pipeline in safetensors format or not.", - ) - parser.add_argument("--dump_path", default=None, type=str, required=False, help="Path to the output model.") - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - - # small workaround to get argparser to parse a boolean input as either true _or_ false - def parse_bool(string): - if string == "True": - return True - elif string == "False": - return False - else: - raise ValueError(f"could not parse string as bool {string}") - - parser.add_argument( - "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool - ) - - parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) - - args = parser.parse_args() - - controlnet = download_controlnet_from_original_ckpt( - checkpoint_path="/home/x/Documents/gits/AnyText/anytext_v1.1.ckpt", - original_config_file="/home/x/Documents/gits/AnyText/models_yaml/anytext_sd15.yaml", - image_size=args.image_size, - extract_ema=args.extract_ema, - num_in_channels=args.num_in_channels, - upcast_attention=args.upcast_attention, - from_safetensors=args.from_safetensors, - device="cpu", - use_linear_projection=args.use_linear_projection, - cross_attention_dim=args.cross_attention_dim, - ) - - controlnet.save_pretrained( - "/home/x/Documents/gits/diffusers/examples/research_projects/anytext", safe_serialization=False - ) diff --git a/examples/research_projects/anytext/convert_from_ckpt.py b/examples/research_projects/anytext/convert_from_ckpt.py deleted file mode 100644 index 172afb30a4f5..000000000000 --- a/examples/research_projects/anytext/convert_from_ckpt.py +++ /dev/null @@ -1,1872 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Conversion script for the Stable Diffusion checkpoints.""" - -import re -from contextlib import nullcontext -from io import BytesIO -from typing import Dict, Optional, Union - -import requests -import torch -import yaml -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from diffusers.models import ( - AutoencoderKL, - ControlNetModel, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available, logging - - -if is_accelerate_available(): - from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) - shape = old_checkpoint[path["old"]].shape - if is_attn_weight and len(shape) == 3: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - elif is_attn_weight and len(shape) == 4: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - if controlnet: - unet_params = original_config["model"]["params"]["control_stage_config"]["params"] - else: - if ( - "unet_config" in original_config["model"]["params"] - and original_config["model"]["params"]["unet_config"] is not None - ): - unet_params = original_config["model"]["params"]["unet_config"]["params"] - else: - unet_params = original_config["model"]["params"]["network_config"]["params"] - - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - - block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - if unet_params["transformer_depth"] is not None: - transformer_layers_per_block = ( - unet_params["transformer_depth"] - if isinstance(unet_params["transformer_depth"], int) - else list(unet_params["transformer_depth"]) - ) - else: - transformer_layers_per_block = 1 - - vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - - head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] - head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] - - class_embed_type = None - addition_embed_type = None - addition_time_embed_dim = None - projection_class_embeddings_input_dim = None - context_dim = None - - if unet_params["context_dim"] is not None: - context_dim = ( - unet_params["context_dim"] - if isinstance(unet_params["context_dim"], int) - else unet_params["context_dim"][0] - ) - - if "num_classes" in unet_params: - if unet_params["num_classes"] == "sequential": - if context_dim in [2048, 1280]: - # SDXL - addition_embed_type = "text_time" - addition_time_embed_dim = 256 - else: - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params["adm_in_channels"] - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params["in_channels"], - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params["num_res_blocks"], - "cross_attention_dim": context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "addition_embed_type": addition_embed_type, - "addition_time_embed_dim": addition_time_embed_dim, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "transformer_layers_per_block": transformer_layers_per_block, - } - - if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params["disable_self_attentions"] - - if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): - config["num_class_embeds"] = unet_params["num_classes"] - - if not controlnet: - config["out_channels"] = unet_params["out_channels"] - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - - block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params["in_channels"], - "out_channels": vae_params["out_ch"], - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params["z_channels"], - "layers_per_block": vae_params["num_res_blocks"], - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config["model"]["params"]["timesteps"], - beta_start=original_config["model"]["params"]["linear_start"], - beta_end=original_config["model"]["params"]["linear_end"], - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False -): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - if skip_extract_state_dict: - unet_state_dict = checkpoint - else: - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") - logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if ( - key.startswith(unet_key) - and not key.startswith("control_model.glyph_block") - and not key.startswith("control_model.position_block") - and not key.startswith("control_model.fuse_block") - ): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... - elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - if config["addition_embed_type"] == "text_time": - new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - - # Relevant to StableDiffusionUpscalePipeline - if "num_class_embeds" in config: - if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): - new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # # conditioning embedding - - # orig_index = 0 - - # new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.weight" - # ) - # new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.bias" - # ) - - # orig_index += 2 - - # diffusers_index = 0 - - # while diffusers_index < 6: - # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.weight" - # ) - # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.bias" - # ) - # diffusers_index += 1 - # orig_index += 2 - - # new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.weight" - # ) - # new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - # f"input_hint_block.{orig_index}.bias" - # ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - keys = list(checkpoint.keys()) - vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): - if text_encoder is None: - config_name = "openai/clip-vit-large-patch14" - try: - config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModel(config) - else: - text_model = text_encoder - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] - - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("ln_final.weight", "text_model.final_layer_norm.weight"), - ("ln_final.bias", "text_model.final_layer_norm.bias"), - ("text_projection", "text_projection.weight"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint( - checkpoint, - config_name, - prefix="cond_stage_model.model.", - has_projection=False, - local_files_only=False, - **config_kwargs, -): - # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - # text_model = CLIPTextModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 - # ) - try: - config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - keys_to_ignore = [] - if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: - # make sure to remove all keys > 22 - keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] - keys_to_ignore += ["cond_stage_model.model.text_projection"] - - text_model_dict = {} - - if prefix + "text_projection" in checkpoint: - d_model = int(checkpoint[prefix + "text_projection"].shape[0]) - else: - d_model = 1024 - - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if key in keys_to_ignore: - continue - if key[len(prefix) :] in textenc_conversion_map: - if key.endswith("text_projection"): - value = checkpoint[key].T.contiguous() - else: - value = checkpoint[key] - - text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value - - if key.startswith(prefix + "transformer."): - new_key = key[len(prefix + "transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config, local_files_only=False): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. - """ - - image_embedder_config = original_config["model"]["params"]["embedder_config"] - - sd_clip_image_embedder_class = image_embedder_config["target"] - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only - ) - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. - """ - noise_aug_config = original_config["model"]["params"]["noise_aug_config"] - noise_aug_class = noise_aug_config["target"] - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=None, - cross_attention_dim=None, -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - if use_linear_projection is not None: - ctrlnet_config["use_linear_projection"] = use_linear_projection - - if cross_attention_dim is not None: - ctrlnet_config["cross_attention_dim"] = cross_attention_dim - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - controlnet = ControlNetModel(**ctrlnet_config) - - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - skip_extract_state_dict = True - else: - skip_extract_state_dict = False - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, - ctrlnet_config, - path=checkpoint_path, - extract_ema=extract_ema, - controlnet=True, - skip_extract_state_dict=skip_extract_state_dict, - ) - - if is_accelerate_available(): - for param_name, param in converted_ctrl_checkpoint.items(): - set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) - else: - controlnet.load_state_dict(converted_ctrl_checkpoint) - - return controlnet - - -def download_from_original_stable_diffusion_ckpt( - checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], - original_config_file: str = None, - image_size: Optional[int] = None, - prediction_type: str = None, - model_type: str = None, - extract_ema: bool = False, - scheduler_type: str = "pndm", - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - stable_unclip: Optional[str] = None, - stable_unclip_prior: Optional[str] = None, - clip_stats_path: Optional[str] = None, - controlnet: Optional[bool] = None, - adapter: Optional[bool] = None, - load_safety_checker: bool = True, - safety_checker: Optional[StableDiffusionSafetyChecker] = None, - feature_extractor: Optional[AutoFeatureExtractor] = None, - pipeline_class: DiffusionPipeline = None, - local_files_only=False, - vae_path=None, - vae=None, - text_encoder=None, - text_encoder_2=None, - tokenizer=None, - tokenizer_2=None, - config_files=None, -) -> DiffusionPipeline: - """ - Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` - config file. - - Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the - global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is - recommended that you override the default values and/or supply an `original_config_file` wherever possible. - - Args: - checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically - inferred by looking for a key that only exists in SD2.0 models. - image_size (`int`, *optional*, defaults to 512): - The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 - Base. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable - Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to None): - The number of input channels. If `None`, it will be automatically inferred. - scheduler_type (`str`, *optional*, defaults to 'pndm'): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - model_type (`str`, *optional*, defaults to `None`): - The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", - "FrozenCLIPEmbedder", "PaintByExample"]`. - is_img2img (`bool`, *optional*, defaults to `False`): - Whether the model should be loaded as an img2img pipeline. - extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for - checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to - `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for - inference. Non-EMA weights are usually better to continue fine-tuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. This is necessary when running stable - diffusion 2.1. - device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. - from_safetensors (`str`, *optional*, defaults to `False`): - If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. Defaults to `True`. - safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`): - Safety checker to use. If this parameter is `None`, the function will load a new instance of - [StableDiffusionSafetyChecker] by itself, if needed. - feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`): - Feature extractor to use. If this parameter is `None`, the function will load a new instance of - [AutoFeatureExtractor] by itself, if needed. - pipeline_class (`str`, *optional*, defaults to `None`): - The pipeline class to use. Pass `None` to determine automatically. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - vae (`AutoencoderKL`, *optional*, defaults to `None`): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If - this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): - An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) - to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) - variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): - An instance of - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if - needed. - config_files (`Dict[str, str]`, *optional*, defaults to `None`): - A dictionary mapping from config file names to their contents. If this parameter is `None`, the function - will load the config files by itself, if needed. Valid keys are: - - `v1`: Config file for Stable Diffusion v1 - - `v2`: Config file for Stable Diffusion v2 - - `xl`: Config file for Stable Diffusion XL - - `xl_refiner`: Config file for Stable Diffusion XL Refiner - return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - """ - - # import pipelines here to avoid circular import error when using from_single_file method - from diffusers import ( - LDMTextToImagePipeline, - PaintByExamplePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionUpscalePipeline, - StableDiffusionXLControlNetInpaintPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - - if prediction_type == "v-prediction": - prediction_type = "v_prediction" - - if isinstance(checkpoint_path_or_dict, str): - if from_safetensors: - from safetensors.torch import load_file as safe_load - - checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) - else: - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) - elif isinstance(checkpoint_path_or_dict, dict): - checkpoint = checkpoint_path_or_dict - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - logger.debug("global_step key not found in model") - global_step = None - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - if original_config_file is None: - key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" - key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" - is_upscale = pipeline_class == StableDiffusionUpscalePipeline - - config_url = None - - # model_type = "v1" - if config_files is not None and "v1" in config_files: - original_config_file = config_files["v1"] - else: - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - - if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: - # model_type = "v2" - if config_files is not None and "v2" in config_files: - original_config_file = config_files["v2"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - elif key_name_sd_xl_base in checkpoint: - # only base xl has two text embedders - if config_files is not None and "xl" in config_files: - original_config_file = config_files["xl"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" - elif key_name_sd_xl_refiner in checkpoint: - # only refiner xl has embedder and one text embedders - if config_files is not None and "xl_refiner" in config_files: - original_config_file = config_files["xl_refiner"] - else: - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - - if is_upscale: - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" - - if config_url is not None: - original_config_file = BytesIO(requests.get(config_url).content) - else: - with open(original_config_file, "r") as f: - original_config_file = f.read() - else: - with open(original_config_file, "r") as f: - original_config_file = f.read() - - original_config = yaml.safe_load(original_config_file) - - # Convert the text model. - if ( - model_type is None - and "cond_stage_config" in original_config["model"]["params"] - and original_config["model"]["params"]["cond_stage_config"] is not None - ): - model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - elif model_type is None and original_config["model"]["params"]["network_config"] is not None: - if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: - model_type = "SDXL" - else: - model_type = "SDXL-Refiner" - if image_size is None: - image_size = 1024 - - if pipeline_class is None: - # Check if we have a SDXL or SD model and initialize default pipeline - if model_type not in ["SDXL", "SDXL-Refiner"]: - pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline - else: - pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline - - if num_in_channels is None and pipeline_class in [ - StableDiffusionInpaintPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLControlNetInpaintPipeline, - ]: - num_in_channels = 9 - if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: - num_in_channels = 7 - elif num_in_channels is None: - num_in_channels = 4 - - if "unet_config" in original_config["model"]["params"]: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - elif "network_config" in original_config["model"]["params"]: - original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - if controlnet is None and "control_stage_config" in original_config["model"]["params"]: - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" - controlnet = convert_controlnet_checkpoint( - checkpoint, original_config, path, image_size, upcast_attention, extract_ema - ) - - if "timesteps" in original_config["model"]["params"]: - num_train_timesteps = original_config["model"]["params"]["timesteps"] - else: - num_train_timesteps = 1000 - - if model_type in ["SDXL", "SDXL-Refiner"]: - scheduler_dict = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": num_train_timesteps, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", - } - scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) - scheduler_type = "euler" - else: - if "linear_start" in original_config["model"]["params"]: - beta_start = original_config["model"]["params"]["linear_start"] - else: - beta_start = 0.02 - - if "linear_end" in original_config["model"]["params"]: - beta_end = original_config["model"]["params"]["linear_end"] - else: - beta_end = 0.085 - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - if pipeline_class == StableDiffusionUpscalePipeline: - image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = upcast_attention - - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=path, extract_ema=extract_ema - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - unet = UNet2DConditionModel(**unet_config) - - if is_accelerate_available(): - if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - else: - unet.load_state_dict(converted_unet_checkpoint) - - # Convert the VAE model. - if vae_path is None and vae is None: - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if ( - "model" in original_config - and "params" in original_config["model"] - and "scale_factor" in original_config["model"]["params"] - ): - vae_scaling_factor = original_config["model"]["params"]["scale_factor"] - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - for param_name, param in converted_vae_checkpoint.items(): - set_module_tensor_to_device(vae, param_name, "cpu", value=param) - else: - vae.load_state_dict(converted_vae_checkpoint) - elif vae is None: - vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) - - if model_type == "FrozenOpenCLIPEmbedder": - config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} - - if text_encoder is None: - text_model = convert_open_clip_checkpoint( - checkpoint, config_name, local_files_only=local_files_only, **config_kwargs - ) - else: - text_model = text_encoder - - try: - tokenizer = CLIPTokenizer.from_pretrained( - "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." - ) - - if stable_unclip is None: - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - controlnet=controlnet, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - if hasattr(pipe, "requires_safety_checker"): - pipe.requires_safety_checker = False - - elif pipeline_class == StableDiffusionUpscalePipeline: - scheduler = DDIMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" - ) - low_res_scheduler = DDPMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" - ) - - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - low_res_scheduler=low_res_scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - else: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - if hasattr(pipe, "requires_safety_checker"): - pipe.requires_safety_checker = False - - else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) - - if stable_unclip == "img2img": - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - pipe = StableUnCLIPImg2ImgPipeline( - # image encoding components - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - elif stable_unclip == "txt2img": - if stable_unclip_prior is None or stable_unclip_prior == "karlo": - karlo_model = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained( - karlo_model, subfolder="prior", local_files_only=local_files_only - ) - - try: - prior_tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - prior_text_model = CLIPTextModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - - prior_scheduler = UnCLIPScheduler.from_pretrained( - karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only - ) - prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) - else: - raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") - - pipe = StableUnCLIPPipeline( - # prior components - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_model, - prior=prior, - prior_scheduler=prior_scheduler, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - else: - raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") - elif model_type == "PaintByExample": - vision_model = convert_paint_by_example_checkpoint(checkpoint) - try: - tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - try: - feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." - ) - pipe = PaintByExamplePipeline( - vae=vae, - image_encoder=vision_model, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=feature_extractor, - ) - elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint( - checkpoint, local_files_only=local_files_only, text_encoder=text_encoder - ) - try: - tokenizer = ( - CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) - if tokenizer is None - else tokenizer - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - - if load_safety_checker: - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only - ) - - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = pipeline_class( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - elif model_type in ["SDXL", "SDXL-Refiner"]: - is_refiner = model_type == "SDXL-Refiner" - - if (is_refiner is False) and (tokenizer is None): - try: - tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." - ) - - if (is_refiner is False) and (text_encoder is None): - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - - if tokenizer_2 is None: - try: - tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." - ) - - if text_encoder_2 is None: - config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config_kwargs = {"projection_dim": 1280} - prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." - - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, - config_name, - prefix=prefix, - has_projection=True, - local_files_only=local_files_only, - **config_kwargs, - ) - - if is_accelerate_available(): # SBM Now move model to cpu. - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - - if controlnet: - pipe = pipeline_class( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - elif adapter: - pipe = pipeline_class( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - adapter=adapter, - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - - else: - pipeline_kwargs = { - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "unet": unet, - "scheduler": scheduler, - } - - if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( - pipeline_class == StableDiffusionXLInpaintPipeline - ): - pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) - - if is_refiner: - pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) - - pipe = pipeline_class(**pipeline_kwargs) - else: - text_config = create_ldm_bert_config(original_config) - text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - return pipe - - -def download_controlnet_from_original_ckpt( - checkpoint_path: str, - original_config_file: str, - image_size: int = 512, - extract_ema: bool = False, - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - use_linear_projection: Optional[bool] = None, - cross_attention_dim: Optional[bool] = None, -) -> DiffusionPipeline: - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - with open(original_config_file, "r") as f: - original_config_file = f.read() - original_config = yaml.safe_load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if "control_stage_config" not in original_config["model"]["params"]: - raise ValueError("`control_stage_config` not present in original config") - - controlnet = convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=use_linear_projection, - cross_attention_dim=cross_attention_dim, - ) - - return controlnet diff --git a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py b/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py deleted file mode 100644 index 4f5fd7aa01a8..000000000000 --- a/examples/research_projects/anytext/convert_original_anytext_to_diffusers.py +++ /dev/null @@ -1 +0,0 @@ -# In construction... From 61693a589877446ed0b74cbbab1f947f4b86ff5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 12:27:05 +0300 Subject: [PATCH 76/87] [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing --- examples/research_projects/anytext/anytext.py | 67 ++- .../anytext/bert_tokenizer.py | 430 ------------------ 2 files changed, 65 insertions(+), 432 deletions(-) delete mode 100644 examples/research_projects/anytext/bert_tokenizer.py diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index fdf52792f045..4adf77b4a5a8 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -33,7 +33,6 @@ import PIL.Image import torch import torch.nn.functional as F -from bert_tokenizer import BasicTokenizer from easydict import EasyDict as edict from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from huggingface_hub import hf_hub_download @@ -71,7 +70,71 @@ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -checker = BasicTokenizer() +class Checker: + def __init__(self): + pass + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or self._is_control(char): + continue + if self._is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_control(self, char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + def _is_whitespace(self, char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +checker = Checker() PLACE_HOLDER = "*" diff --git a/examples/research_projects/anytext/bert_tokenizer.py b/examples/research_projects/anytext/bert_tokenizer.py deleted file mode 100644 index fd1e0ce32c47..000000000000 --- a/examples/research_projects/anytext/bert_tokenizer.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright 2018 The Google AI Language Team Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: Try to use the `transformers` library instead of this custom implementation if possible. -"""Tokenization classes.""" - -from __future__ import absolute_import, division, print_function - -import collections -import re -import unicodedata - -import six - - -def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): - """Checks whether the casing config is consistent with the checkpoint name.""" - - # The casing has to be passed in by the user and there is no explicit check - # as to whether it matches the checkpoint. The casing information probably - # should have been stored in the bert_config.json file, but it's not, so - # we have to heuristically detect it to validate. - - if not init_checkpoint: - return - - m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) - if m is None: - return - - model_name = m.group(1) - - lower_models = [ - "uncased_L-24_H-1024_A-16", - "uncased_L-12_H-768_A-12", - "multilingual_L-12_H-768_A-12", - "chinese_L-12_H-768_A-12", - ] - - cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"] - - is_bad_config = False - if model_name in lower_models and not do_lower_case: - is_bad_config = True - actual_flag = "False" - case_name = "lowercased" - opposite_flag = "True" - - if model_name in cased_models and do_lower_case: - is_bad_config = True - actual_flag = "True" - case_name = "cased" - opposite_flag = "False" - - if is_bad_config: - raise ValueError( - "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " - "However, `%s` seems to be a %s model, so you " - "should pass in `--do_lower_case=%s` so that the fine-tuning matches " - "how the model was pre-training. If this error is wrong, please " - "just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag) - ) - - -def convert_to_unicode(text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): # type: ignore # noqa: F821 - return text - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") - - -def printable_text(text): - """Returns text encoded in a way suitable for print or `tf.logging`.""" - - # These functions want `str` for both Python2 and Python3, but in one case - # it's a Unicode string and in the other it's a byte string. - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text - elif isinstance(text, unicode): # type: ignore # noqa: F821 - return text.encode("utf-8") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") - - -def load_vocab(vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - index = 0 - with open(vocab_file, "r", encoding="utf-8") as reader: - while True: - token = convert_to_unicode(reader.readline()) - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 - return vocab - - -def convert_by_vocab(vocab, items): - """Converts a sequence of [tokens|ids] using the vocab.""" - output = [] - for item in items: - output.append(vocab[item]) - return output - - -def convert_tokens_to_ids(vocab, tokens): - return convert_by_vocab(vocab, tokens) - - -def convert_ids_to_tokens(inv_vocab, ids): - return convert_by_vocab(inv_vocab, ids) - - -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens - - -class FullTokenizer(object): - """Runs end-to-end tokenization.""" - - def __init__(self, vocab_file, do_lower_case=True): - self.vocab = load_vocab(vocab_file) - self.inv_vocab = {v: k for k, v in self.vocab.items()} - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) - - def tokenize(self, text): - split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) - - return split_tokens - - def convert_tokens_to_ids(self, tokens): - return convert_by_vocab(self.vocab, tokens) - - def convert_ids_to_tokens(self, ids): - return convert_by_vocab(self.inv_vocab, ids) - - @staticmethod - def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): - """Converts a sequence of tokens (string) in a single string.""" - - def clean_up_tokenization(out_string): - """Clean up a list of simple English tokenization artifacts - like spaces before punctuations and abbreviated forms. - """ - out_string = ( - out_string.replace(" .", ".") - .replace(" ?", "?") - .replace(" !", "!") - .replace(" ,", ",") - .replace(" ' ", "'") - .replace(" n't", "n't") - .replace(" 'm", "'m") - .replace(" 's", "'s") - .replace(" 've", "'ve") - .replace(" 're", "'re") - ) - return out_string - - text = " ".join(tokens).replace(" ##", "").strip() - if clean_up_tokenization_spaces: - clean_text = clean_up_tokenization(text) - return clean_text - else: - return text - - def vocab_size(self): - return len(self.vocab) - - -class BasicTokenizer(object): - """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - - def __init__(self, do_lower_case=True): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. - """ - self.do_lower_case = do_lower_case - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = convert_to_unicode(text) - text = self._clean_text(text) - - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). - text = self._tokenize_chinese_chars(text) - - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. - if ( - (cp >= 0x4E00 and cp <= 0x9FFF) - or (cp >= 0x3400 and cp <= 0x4DBF) - or (cp >= 0x20000 and cp <= 0x2A6DF) - or (cp >= 0x2A700 and cp <= 0x2B73F) - or (cp >= 0x2B740 and cp <= 0x2B81F) - or (cp >= 0x2B820 and cp <= 0x2CEAF) - or (cp >= 0xF900 and cp <= 0xFAFF) - or (cp >= 0x2F800 and cp <= 0x2FA1F) - ): - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xFFFD or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - -class WordpieceTokenizer(object): - """Runs WordPiece tokenization.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer. - - Returns: - A list of wordpiece tokens. - """ - - text = convert_to_unicode(text) - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically control characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False From 3b2435fd574c521912f49f8298a023d9677e4444 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 12:27:39 +0300 Subject: [PATCH 77/87] update pillow --- examples/research_projects/anytext/anytext.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 4adf77b4a5a8..3fa61d202e0e 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -723,10 +723,11 @@ def draw_glyph(self, font, text): ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) new_font = font.font_variant(size=int(g_size * ratio)) - text_width, text_height = new_font.getsize(text) - offset_x, offset_y = new_font.getoffset(text) + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - offset_y // 2 + y = (img.height - text_height) // 2 - top // 2 draw.text((x, y), text, font=new_font, fill="white") img = np.expand_dims(np.array(img), axis=2).astype(np.float64) return img From 3ea49c15149e2c4da8bc4881dd442a4797e83429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 12:28:36 +0300 Subject: [PATCH 78/87] [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity --- examples/research_projects/anytext/anytext.py | 9 +-------- examples/research_projects/anytext/anytext_controlnet.py | 9 --------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 3fa61d202e0e..1bf4999762af 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -25,6 +25,7 @@ import os import re import sys +import unicodedata from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -324,12 +325,6 @@ def adjust_image(box, img): return result -""" -mask: numpy.ndarray, mask of textual, HWC -src_img: torch.Tensor, source image, CHW -""" - - def crop_image(src_img, mask): box = min_bounding_rect(mask) result = adjust_image(box, src_img) @@ -526,10 +521,8 @@ def get_ctcloss(self, preds, gt_text, weight): class TextEmbeddingModule(nn.Module): - # @register_to_config def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() - # TODO: Learn if the recommended font file is free to use self.font = ImageFont.truetype(font_path, 60) self.use_fp16 = use_fp16 self.device = device diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index 81f65a8315a4..51e47cdf6a6f 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -85,15 +85,6 @@ def __init__( self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1) - # self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device))) - # self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device))) - # self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device))) - - # if use_fp16: - # self.glyph_block = self.glyph_block.to(dtype=torch.float16) - # self.position_block = self.position_block.to(dtype=torch.float16) - # self.fuse_block = self.fuse_block.to(dtype=torch.float16) - def forward(self, glyphs, positions, text_info): glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device)) position_embedding = self.position_block(positions.to(self.position_block[0].weight.device)) From 299a646d9db7433ca270dc0761ba470a00e1d832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 22 Feb 2025 15:33:05 +0300 Subject: [PATCH 79/87] [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file --- examples/research_projects/anytext/anytext.py | 217 +++++++++++++++++- .../anytext/frozen_clip_embedder_t3.py | 214 ----------------- 2 files changed, 216 insertions(+), 215 deletions(-) delete mode 100644 examples/research_projects/anytext/frozen_clip_embedder_t3.py diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 1bf4999762af..8b4e1913551d 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -35,7 +35,6 @@ import torch import torch.nn.functional as F from easydict import EasyDict as edict -from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3 from huggingface_hub import hf_hub_download from ocr_recog.RecModel import RecModel from PIL import Image, ImageDraw, ImageFont @@ -520,6 +519,222 @@ def get_ctcloss(self, preds, gt_text, weight): return loss +import torch +from torch import nn +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedderT3(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cpu", + max_length=77, + freeze=True, + use_fp16=False, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained( + version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32 + ).to(device) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def embedding_forward( + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + embedding_manager=None, + ): + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__( + self.transformer.text_model.embeddings + ) + + def encoder_forward( + self, + inputs_embeds, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) + + def text_encoder_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None: + raise ValueError("You have to specify either input_ids") + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager + ) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager=embedding_manager, + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, **kwargs): + batch_encoding = self.tokenizer( + text, + truncation=False, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="longest", + return_tensors="pt", + ) + input_ids = batch_encoding["input_ids"] + tokens_list = self.split_chunks(input_ids) + z_list = [] + for tokens in tokens_list: + tokens = tokens.to(self.device) + _z = self.transformer(input_ids=tokens, **kwargs) + z_list += [_z] + return torch.cat(z_list, dim=1) + + def encode(self, text, **kwargs): + return self(text, **kwargs) + + def split_chunks(self, input_ids, chunk_size=75): + tokens_list = [] + bs, n = input_ids.shape + id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] + id_end = input_ids[:, -1].unsqueeze(1) + if n == 2: # empty caption + tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) + + trimmed_encoding = input_ids[:, 1:-1] + num_full_groups = (n - 2) // chunk_size + + for i in range(num_full_groups): + group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] + group_pad = torch.cat((id_start, group, id_end), dim=1) + tokens_list.append(group_pad) + + remaining_columns = (n - 2) % chunk_size + if remaining_columns > 0: + remaining_group = trimmed_encoding[:, -remaining_columns:] + padding_columns = chunk_size - remaining_group.shape[1] + padding = id_end.expand(bs, padding_columns) + remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) + tokens_list.append(remaining_group_pad) + return tokens_list + + def to(self, *args, **kwargs): + self.transformer = self.transformer.to(*args, **kwargs) + self.device = self.transformer.device + return self + + class TextEmbeddingModule(nn.Module): def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() diff --git a/examples/research_projects/anytext/frozen_clip_embedder_t3.py b/examples/research_projects/anytext/frozen_clip_embedder_t3.py deleted file mode 100644 index 00f33109b3d0..000000000000 --- a/examples/research_projects/anytext/frozen_clip_embedder_t3.py +++ /dev/null @@ -1,214 +0,0 @@ -import torch -from torch import nn -from transformers import CLIPTextModel, CLIPTokenizer -from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class FrozenCLIPEmbedderT3(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from Hugging Face)""" - - def __init__( - self, - version="openai/clip-vit-large-patch14", - device="cpu", - max_length=77, - freeze=True, - use_fp16=False, - ): - super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained( - version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32 - ).to(device) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - def embedding_forward( - self, - input_ids=None, - position_ids=None, - inputs_embeds=None, - embedding_manager=None, - ): - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - if embedding_manager is not None: - inputs_embeds = embedding_manager(input_ids, inputs_embeds) - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - return embeddings - - self.transformer.text_model.embeddings.forward = embedding_forward.__get__( - self.transformer.text_model.embeddings - ) - - def encoder_forward( - self, - inputs_embeds, - attention_mask=None, - causal_attention_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - return hidden_states - - self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) - - def text_encoder_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is None: - raise ValueError("You have to specify either input_ids") - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - hidden_states = self.embeddings( - input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager - ) - # CLIP's text model uses causal mask, prepare it here. - # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _create_4d_causal_attention_mask( - input_shape, hidden_states.dtype, device=hidden_states.device - ) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - last_hidden_state = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_state = self.final_layer_norm(last_hidden_state) - return last_hidden_state - - self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) - - def transformer_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - embedding_manager=embedding_manager, - ) - - self.transformer.forward = transformer_forward.__get__(self.transformer) - - def freeze(self): - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text, **kwargs): - batch_encoding = self.tokenizer( - text, - truncation=False, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="longest", - return_tensors="pt", - ) - input_ids = batch_encoding["input_ids"] - tokens_list = self.split_chunks(input_ids) - z_list = [] - for tokens in tokens_list: - tokens = tokens.to(self.device) - _z = self.transformer(input_ids=tokens, **kwargs) - z_list += [_z] - return torch.cat(z_list, dim=1) - - def encode(self, text, **kwargs): - return self(text, **kwargs) - - def split_chunks(self, input_ids, chunk_size=75): - tokens_list = [] - bs, n = input_ids.shape - id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] - id_end = input_ids[:, -1].unsqueeze(1) - if n == 2: # empty caption - tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) - - trimmed_encoding = input_ids[:, 1:-1] - num_full_groups = (n - 2) // chunk_size - - for i in range(num_full_groups): - group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] - group_pad = torch.cat((id_start, group, id_end), dim=1) - tokens_list.append(group_pad) - - remaining_columns = (n - 2) % chunk_size - if remaining_columns > 0: - remaining_group = trimmed_encoding[:, -remaining_columns:] - padding_columns = chunk_size - remaining_group.shape[1] - padding = id_end.expand(bs, padding_columns) - remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) - tokens_list.append(remaining_group_pad) - return tokens_list - - def to(self, *args, **kwargs): - self.transformer = self.transformer.to(*args, **kwargs) - self.device = self.transformer.device - return self From 0d44b5baf1ebb44b2e5d5afe17fb580b25d36dce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 24 Feb 2025 11:49:48 +0300 Subject: [PATCH 80/87] [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency --- examples/research_projects/anytext/anytext.py | 8 ++++---- .../research_projects/anytext/ocr_recog/RecModel.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 8b4e1913551d..ecda444654de 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -348,11 +348,11 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal n_class = 97 else: raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = edict( + rec_config = dict( in_channels=3, - backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), - neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), - head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), + backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), + neck=dict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), + head=dict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), ) rec_model = RecModel(rec_config) diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py index 5f0f8f0375f1..26c988333a5a 100755 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -14,17 +14,17 @@ class RecModel(nn.Module): def __init__(self, config): super().__init__() assert "in_channels" in config, "in_channels must in model config" - backbone_type = config.backbone.pop("type") + backbone_type = config["backbone"].pop("type") assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}" - self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) + self.backbone = backbone_dict[backbone_type](config['in_channels'], **config['backbone']) - neck_type = config.neck.pop("type") + neck_type = config['neck'].pop("type") assert neck_type in neck_dict, f"neck.type must in {neck_dict}" - self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) + self.neck = neck_dict[neck_type](self.backbone.out_channels, **config['neck']) - head_type = config.head.pop("type") + head_type = config['head'].pop("type") assert head_type in head_dict, f"head.type must in {head_dict}" - self.head = head_dict[head_type](self.neck.out_channels, **config.head) + self.head = head_dict[head_type](self.neck.out_channels, **config['head']) self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}" From 13b7ecf241d523a981eb63e5be707aae44e4dc09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 24 Feb 2025 16:49:17 +0300 Subject: [PATCH 81/87] =?UTF-8?q?=F0=9F=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/research_projects/anytext/anytext.py | 173 +++++++++--------- 1 file changed, 91 insertions(+), 82 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index ecda444654de..1010ac8de6e2 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -34,7 +34,6 @@ import PIL.Image import torch import torch.nn.functional as F -from easydict import EasyDict as edict from huggingface_hub import hf_hub_download from ocr_recog.RecModel import RecModel from PIL import Image, ImageDraw, ImageFont @@ -58,6 +57,8 @@ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.configuration_utils import register_to_config, ConfigMixin +from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import ( USE_PEFT_BACKEND, deprecate, @@ -203,18 +204,18 @@ def get_recog_emb(encoder, img_list): return preds_neck -class EmbeddingManager(nn.Module): +class EmbeddingManager(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, embedder, placeholder_string="*", use_fp16=False, + token_dim = 768, + get_recog_emb = None, ): super().__init__() get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - token_dim = 768 - self.get_recog_emb = None - self.token_dim = token_dim self.proj = nn.Linear(40 * 64, token_dim) proj_dir = hf_hub_download( @@ -226,12 +227,14 @@ def __init__( if use_fp16: self.proj = self.proj.to(dtype=torch.float16) + # self.register_parameter("proj", proj) self.placeholder_token = get_token_for_string(placeholder_string) + # self.register_config(placeholder_token=placeholder_token) @torch.no_grad() def encode_text(self, text_info): - if self.get_recog_emb is None: - self.get_recog_emb = partial(get_recog_emb, self.recog) + if self.config.get_recog_emb is None: + self.config.get_recog_emb = partial(get_recog_emb, self.recog) gline_list = [] for i in range(len(text_info["n_lines"])): # sample index in a batch @@ -240,7 +243,7 @@ def encode_text(self, text_info): gline_list += [text_info["gly_line"][j][i : i + 1]] if len(gline_list) > 0: - recog_emb = self.get_recog_emb(gline_list) + recog_emb = self.config.get_recog_emb(gline_list) enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) self.text_embs_all = [] @@ -332,13 +335,12 @@ def crop_image(src_img, mask): return result -def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False): - if model_dir is None or not os.path.exists(model_dir): - model_dir = hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/OCR/ppv3_rec.pth", - cache_dir=HF_MODULES_CACHE, - ) +def create_predictor(model_lang="ch", device="cpu", use_fp16=False): + model_dir = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppv3_rec.pth", + cache_dir=HF_MODULES_CACHE, + ) if not os.path.exists(model_dir): raise ValueError("not find model file path {}".format(model_dir)) @@ -533,24 +535,24 @@ def encode(self, *args, **kwargs): raise NotImplementedError -class FrozenCLIPEmbedderT3(AbstractEncoder): +class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin): """Uses the CLIP transformer encoder for text (from Hugging Face)""" - + @register_to_config def __init__( self, - version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, use_fp16=False, + variant: Optional[str] = None, ): super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained( - version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32 - ).to(device) - self.device = device - self.max_length = max_length + self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder", + torch_dtype=torch.float16 if use_fp16 else torch.float32, + variant="fp16" if use_fp16 else None) + # self.device = device + # self.max_length = max_length if freeze: self.freeze() @@ -686,7 +688,7 @@ def forward(self, text, **kwargs): batch_encoding = self.tokenizer( text, truncation=False, - max_length=self.max_length, + max_length=self.config.max_length, return_length=True, return_overflowing_tokens=False, padding="longest", @@ -729,34 +731,39 @@ def split_chunks(self, input_ids, chunk_size=75): tokens_list.append(remaining_group_pad) return tokens_list - def to(self, *args, **kwargs): - self.transformer = self.transformer.to(*args, **kwargs) - self.device = self.transformer.device - return self + # def to(self, *args, **kwargs): + # self.transformer = self.transformer.to(*args, **kwargs) + # self.device = self.transformer.device + # return self -class TextEmbeddingModule(nn.Module): +class TextEmbeddingModule(ModelMixin, ConfigMixin): + @register_to_config def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() - self.font = ImageFont.truetype(font_path, 60) - self.use_fp16 = use_fp16 - self.device = device + font = ImageFont.truetype(font_path, 60) + + # self.use_fp16 = use_fp16 + # self.device = device self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) - rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth" - self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval() - args = {} - args["rec_image_shape"] = "3, 48, 320" - args["rec_batch_num"] = 6 - args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt" - args["rec_char_dict_path"] = hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/OCR/ppocr_keys_v1.txt", - cache_dir=HF_MODULES_CACHE, - ) - args["use_fp16"] = use_fp16 + self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval() + args = {"rec_image_shape": "3, 48, 320", + "rec_batch_num": 6, + "rec_char_dict_path": hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppocr_keys_v1.txt", + cache_dir=HF_MODULES_CACHE, + ), + "use_fp16": use_fp16} self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) + # self.register_modules( + # frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3, + # embedding_manager=embedding_manager, + # ) + self.register_to_config(font=font) + @torch.no_grad() def forward( self, @@ -837,9 +844,9 @@ def forward( text = text[:max_chars] gly_scale = 2 if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.font, text) + gly_line = self.draw_glyph(self.config.font, text) glyphs = self.draw_glyph2( - self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False ) if revise_pos: resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) @@ -881,7 +888,7 @@ def forward( def arr2tensor(self, arr, bs): arr = np.transpose(arr, (2, 0, 1)) _arr = torch.from_numpy(arr.copy()).float().cpu() - if self.use_fp16: + if self.config.use_fp16: _arr = _arr.half() _arr = torch.stack([_arr for _ in range(bs)], dim=0) return _arr @@ -1021,12 +1028,10 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - def to(self, *args, **kwargs): - self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) - self.embedding_manager = self.embedding_manager.to(*args, **kwargs) - self.text_predictor = self.text_predictor.to(*args, **kwargs) - self.device = self.frozen_CLIP_embedder_t3.device - return self + # def to(self, *args, **kwargs): + # self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) + # self.embedding_manager = self.embedding_manager.to(*args, **kwargs) + # return self # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -1043,20 +1048,17 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class AuxiliaryLatentModule(nn.Module): +class AuxiliaryLatentModule(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, - font_path, - vae=None, + # font_path, + vae, device="cpu", - use_fp16=False, ): super().__init__() - self.font = ImageFont.truetype(font_path, 60) - self.use_fp16 = use_fp16 - self.device = device - - self.vae = vae.eval() if vae is not None else None + # self.font = ImageFont.truetype(font_path, 60) + # self.vae = vae.eval() if vae is not None else None @torch.no_grad() def forward( @@ -1093,12 +1095,13 @@ def forward( # get masked_x masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) masked_img = np.transpose(masked_img, (2, 0, 1)) - device = next(self.vae.parameters()).device + device = next(self.config.vae.parameters()).device + dtype = next(self.config.vae.parameters()).dtype masked_img = torch.from_numpy(masked_img.copy()).float().to(device) - if self.use_fp16: + if dtype == torch.float16: masked_img = masked_img.half() - masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach() - if self.use_fp16: + masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach() + if dtype == torch.float16: masked_x = masked_x.half() text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) @@ -1137,10 +1140,10 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - def to(self, *args, **kwargs): - self.vae = self.vae.to(*args, **kwargs) - self.device = self.vae.device - return self + # def to(self, *args, **kwargs): + # self.vae = self.vae.to(*args, **kwargs) + # self.device = self.vae.device + # return self # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps @@ -1255,7 +1258,6 @@ class AnyTextPipeline( def __init__( self, - font_path: str, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, @@ -1264,18 +1266,25 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + font_path: str = None, + text_embedding_module: Optional[TextEmbeddingModule] = None, + auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None, trust_remote_code: bool = False, - text_embedding_module: TextEmbeddingModule = None, - auxiliary_latent_module: AuxiliaryLatentModule = None, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() - self.text_embedding_module = TextEmbeddingModule( - use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + if font_path is None: + raise ValueError("font_path is required!") + + text_embedding_module = TextEmbeddingModule( + font_path=font_path, + use_fp16=unet.dtype == torch.float16, ) - self.auxiliary_latent_module = AuxiliaryLatentModule( - vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path + auxiliary_latent_module = AuxiliaryLatentModule( + # font_path=font_path, + vae=vae, + # use_fp16=unet.dtype == torch.float16, ) if safety_checker is None and requires_safety_checker: @@ -1307,15 +1316,15 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, - text_embedding_module=self.text_embedding_module, - auxiliary_latent_module=self.auxiliary_latent_module, + text_embedding_module=text_embedding_module, + auxiliary_latent_module=auxiliary_latent_module, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path) + self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path) def modify_prompt(self, prompt): prompt = prompt.replace("“", '"') @@ -2331,7 +2340,7 @@ def __call__( cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, + control_model_input.to(self.controlnet.dtype), t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=guided_hint, From d5a6e5f068cdbd593e3b106ac994f8a415f10570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 26 Feb 2025 12:26:51 +0300 Subject: [PATCH 82/87] style --- examples/research_projects/anytext/anytext.py | 119 +++++++----------- .../anytext/ocr_recog/RecModel.py | 10 +- 2 files changed, 48 insertions(+), 81 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 1010ac8de6e2..d497db2859d3 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -41,8 +41,10 @@ from skimage.transform._geometric import _umeyama as get_sym_mat from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import ( FromSingleFileMixin, @@ -52,13 +54,12 @@ ) from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.models.modeling_utils import ModelMixin from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.configuration_utils import register_to_config, ConfigMixin -from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import ( USE_PEFT_BACKEND, deprecate, @@ -154,21 +155,14 @@ def _is_whitespace(self, char): >>> # I chose a font file shared by an HF staff: >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf - >>> # load control net and stable diffusion v1-5 >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, ... variant="fp16",) >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", ... controlnet=anytext_controlnet, torch_dtype=torch.float16, - ... trust_remote_code=True, + ... trust_remote_code=False, # One needs to give permission to run this pipeline's code ... ).to("cuda") >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) - >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization - >>> #pipe.enable_xformers_memory_efficient_attention() - - >>> # uncomment following line if you want to offload the model to CPU for memory optimization - >>> # also remove the `.to("cuda")` part - >>> #pipe.enable_model_cpu_offload() >>> # generate image >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' @@ -211,8 +205,8 @@ def __init__( embedder, placeholder_string="*", use_fp16=False, - token_dim = 768, - get_recog_emb = None, + token_dim=768, + get_recog_emb=None, ): super().__init__() get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) @@ -227,9 +221,7 @@ def __init__( if use_fp16: self.proj = self.proj.to(dtype=torch.float16) - # self.register_parameter("proj", proj) self.placeholder_token = get_token_for_string(placeholder_string) - # self.register_config(placeholder_token=placeholder_token) @torch.no_grad() def encode_text(self, text_info): @@ -350,12 +342,19 @@ def create_predictor(model_lang="ch", device="cpu", use_fp16=False): n_class = 97 else: raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = dict( - in_channels=3, - backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"), - neck=dict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True), - head=dict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True), - ) + rec_config = { + "in_channels": 3, + "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"}, + "neck": { + "type": "SequenceEncoder", + "encoder_type": "svtr", + "dims": 64, + "depth": 2, + "hidden_dims": 120, + "use_guide": True, + }, + "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True}, + } rec_model = RecModel(rec_config) state_dict = torch.load(model_dir, map_location=device) @@ -521,12 +520,6 @@ def get_ctcloss(self, preds, gt_text, weight): return loss -import torch -from torch import nn -from transformers import CLIPTextModel, CLIPTokenizer -from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask - - class AbstractEncoder(nn.Module): def __init__(self): super().__init__() @@ -537,6 +530,7 @@ def encode(self, *args, **kwargs): class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin): """Uses the CLIP transformer encoder for text (from Hugging Face)""" + @register_to_config def __init__( self, @@ -548,11 +542,13 @@ def __init__( ): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer") - self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder", - torch_dtype=torch.float16 if use_fp16 else torch.float32, - variant="fp16" if use_fp16 else None) - # self.device = device - # self.max_length = max_length + self.transformer = CLIPTextModel.from_pretrained( + "tolgacangoz/anytext", + subfolder="text_encoder", + torch_dtype=torch.float16 if use_fp16 else torch.float32, + variant="fp16" if use_fp16 else None, + ) + if freeze: self.freeze() @@ -731,11 +727,6 @@ def split_chunks(self, input_ids, chunk_size=75): tokens_list.append(remaining_group_pad) return tokens_list - # def to(self, *args, **kwargs): - # self.transformer = self.transformer.to(*args, **kwargs) - # self.device = self.transformer.device - # return self - class TextEmbeddingModule(ModelMixin, ConfigMixin): @register_to_config @@ -743,25 +734,21 @@ def __init__(self, font_path, use_fp16=False, device="cpu"): super().__init__() font = ImageFont.truetype(font_path, 60) - # self.use_fp16 = use_fp16 - # self.device = device self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval() - args = {"rec_image_shape": "3, 48, 320", - "rec_batch_num": 6, - "rec_char_dict_path": hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/OCR/ppocr_keys_v1.txt", - cache_dir=HF_MODULES_CACHE, - ), - "use_fp16": use_fp16} + args = { + "rec_image_shape": "3, 48, 320", + "rec_batch_num": 6, + "rec_char_dict_path": hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppocr_keys_v1.txt", + cache_dir=HF_MODULES_CACHE, + ), + "use_fp16": use_fp16, + } self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) - # self.register_modules( - # frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3, - # embedding_manager=embedding_manager, - # ) self.register_to_config(font=font) @torch.no_grad() @@ -873,8 +860,6 @@ def forward( text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] - # hint = self.arr2tensor(np_hint, len(prompt)) - self.embedding_manager.encode_text(text_info) prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) @@ -1028,11 +1013,6 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - # def to(self, *args, **kwargs): - # self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs) - # self.embedding_manager = self.embedding_manager.to(*args, **kwargs) - # return self - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( @@ -1052,13 +1032,10 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - # font_path, vae, device="cpu", ): super().__init__() - # self.font = ImageFont.truetype(font_path, 60) - # self.vae = vae.eval() if vae is not None else None @torch.no_grad() def forward( @@ -1100,7 +1077,9 @@ def forward( masked_img = torch.from_numpy(masked_img.copy()).float().to(device) if dtype == torch.float16: masked_img = masked_img.half() - masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach() + masked_x = ( + retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor + ).detach() if dtype == torch.float16: masked_x = masked_x.half() text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) @@ -1140,11 +1119,6 @@ def insert_spaces(self, string, nSpace): new_string += char + " " * nSpace return new_string[:-nSpace] - # def to(self, *args, **kwargs): - # self.vae = self.vae.to(*args, **kwargs) - # self.device = self.vae.device - # return self - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( @@ -1277,15 +1251,8 @@ def __init__( if font_path is None: raise ValueError("font_path is required!") - text_embedding_module = TextEmbeddingModule( - font_path=font_path, - use_fp16=unet.dtype == torch.float16, - ) - auxiliary_latent_module = AuxiliaryLatentModule( - # font_path=font_path, - vae=vae, - # use_fp16=unet.dtype == torch.float16, - ) + text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16) + auxiliary_latent_module = AuxiliaryLatentModule(vae=vae) if safety_checker is None and requires_safety_checker: logger.warning( @@ -1324,7 +1291,7 @@ def __init__( self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path) + self.register_to_config(requires_safety_checker=requires_safety_checker) def modify_prompt(self, prompt): prompt = prompt.replace("“", '"') diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py index 26c988333a5a..872ccade69e0 100755 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -16,15 +16,15 @@ def __init__(self, config): assert "in_channels" in config, "in_channels must in model config" backbone_type = config["backbone"].pop("type") assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}" - self.backbone = backbone_dict[backbone_type](config['in_channels'], **config['backbone']) + self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"]) - neck_type = config['neck'].pop("type") + neck_type = config["neck"].pop("type") assert neck_type in neck_dict, f"neck.type must in {neck_dict}" - self.neck = neck_dict[neck_type](self.backbone.out_channels, **config['neck']) + self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"]) - head_type = config['head'].pop("type") + head_type = config["head"].pop("type") assert head_type in head_dict, f"head.type must in {head_dict}" - self.head = head_dict[head_type](self.neck.out_channels, **config['head']) + self.head = head_dict[head_type](self.neck.out_channels, **config["head"]) self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}" From 09fdd2228f4fef23f695507a7f3b57c2542a0f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 26 Feb 2025 12:54:54 +0300 Subject: [PATCH 83/87] [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py --- examples/research_projects/anytext/README.md | 23 +++++-------------- examples/research_projects/anytext/anytext.py | 2 -- .../anytext/anytext_controlnet.py | 8 +++++++ 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index 2c2d2e131cdc..c291c23955c0 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -1,43 +1,32 @@ # AnyTextPipeline Pipeline -From the repo [page](https://github.com/tyxsspa/AnyText) +Project page: https://aigcdesigngroup.github.io/homepage_anytext "AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy." -For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). +Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). ```py import torch from diffusers import DiffusionPipeline from anytext_controlnet import AnyTextControlNetModel -from diffusers import DDIMScheduler from diffusers.utils import load_image - # I chose a font file shared by an HF staff: !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf -# load control net and stable diffusion v1-5 anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, variant="fp16",) pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", - controlnet=anytext_controlnet, torch_dtype=torch.float16, - trust_remote_code=True, - ).to("cuda") - -pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -# uncomment following line if PyTorch>=2.0 is not installed for memory optimization -#pipe.enable_xformers_memory_efficient_attention() - -# uncomment following line if you want to offload the model to CPU for memory optimization -# also remove the `.to("cuda")` part -#pipe.enable_model_cpu_offload() + controlnet=anytext_controlnet, torch_dtype=torch.float16, + trust_remote_code=False, # One needs to give permission to run this pipeline's code + ).to("cuda") # generate image prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, - ).images[0] + ).images[0] image ``` diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index d497db2859d3..db54c3172ba9 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -149,7 +149,6 @@ def _is_whitespace(self, char): >>> import torch >>> from diffusers import DiffusionPipeline >>> from anytext_controlnet import AnyTextControlNetModel - >>> from diffusers import DDIMScheduler >>> from diffusers.utils import load_image >>> # I chose a font file shared by an HF staff: @@ -162,7 +161,6 @@ def _is_whitespace(self, char): ... trust_remote_code=False, # One needs to give permission to run this pipeline's code ... ).to("cuda") - >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> # generate image >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py index 51e47cdf6a6f..5965ceed1370 100644 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -11,6 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). +# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie +# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license +# +# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). + + from typing import Any, Dict, Optional, Tuple, Union import torch From 9495ddb92049a590728ba941428d9fbbbe953dda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 26 Feb 2025 13:19:28 +0300 Subject: [PATCH 84/87] style --- examples/research_projects/anytext/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index c291c23955c0..bbaecbf976a5 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -7,6 +7,7 @@ Project page: https://aigcdesigngroup.github.io/homepage_anytext Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). + ```py import torch from diffusers import DiffusionPipeline From f4abaf2196d90bc29933b8f2bf38202f80b1a49f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Sat, 1 Mar 2025 20:26:52 +0300 Subject: [PATCH 85/87] Update examples/research_projects/anytext/README.md Co-authored-by: Aryan --- examples/research_projects/anytext/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index bbaecbf976a5..592a9da248b0 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -15,7 +15,7 @@ from anytext_controlnet import AnyTextControlNetModel from diffusers.utils import load_image # I chose a font file shared by an HF staff: -!wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf +# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, variant="fp16",) From 8d313bce109a67f13f58dc68427ac7783f13a537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 2 Mar 2025 10:33:02 +0300 Subject: [PATCH 86/87] Remove commented-out image preparation code in AnyTextPipeline --- examples/research_projects/anytext/anytext.py | 37 ------------------- 1 file changed, 37 deletions(-) diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index db54c3172ba9..518452f97942 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -2184,18 +2184,6 @@ def __call__( # 4. Prepare image if isinstance(controlnet, ControlNetModel): - # image = self.prepare_image( - # image=image, - # width=width, - # height=height, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=self.do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - # height, width = image.shape[-2:] guided_hint = self.auxiliary_latent_module( text_info=text_info, mode=mode, @@ -2205,31 +2193,6 @@ def __call__( np_hint=np_hint, ) height, width = 512, 512 - # elif isinstance(controlnet, MultiControlNetModel): - # images = [] - - # # Nested lists as ControlNet condition - # if isinstance(image[0], list): - # # Transpose the nested image list - # image = [list(t) for t in zip(*image)] - - # for image_ in image: - # image_ = self.prepare_image( - # image=image_, - # width=width, - # height=height, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=self.do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - - # images.append(image_) - - # image = images - # height, width = image[0].shape[-2:] else: assert False From d02615f179cb5686b93b1b1c4f140618707273fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 2 Mar 2025 10:59:19 +0300 Subject: [PATCH 87/87] Remove unnecessary blank line in README.md --- examples/research_projects/anytext/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md index 592a9da248b0..f5f4fe59ddfd 100644 --- a/examples/research_projects/anytext/README.md +++ b/examples/research_projects/anytext/README.md @@ -7,7 +7,6 @@ Project page: https://aigcdesigngroup.github.io/homepage_anytext Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). - ```py import torch from diffusers import DiffusionPipeline